959a391334
Some checks failed
publish docs / publish-docs (push) Has been cancelled
release-please / release-please (push) Has been cancelled
tests / setup (push) Has been cancelled
tests / ${{ matrix.quality-command }} (black) (push) Has been cancelled
tests / ${{ matrix.quality-command }} (mypy) (push) Has been cancelled
tests / ${{ matrix.quality-command }} (ruff) (push) Has been cancelled
tests / test (push) Has been cancelled
tests / all_checks_passed (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
628 lines
22 KiB
Python
628 lines
22 KiB
Python
from typing import Any, Literal
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from private_gpt.settings.settings_loader import load_active_settings
|
|
|
|
|
|
class CorsSettings(BaseModel):
    """CORS configuration.

    For more details on the CORS configuration, see:
    * https://fastapi.tiangolo.com/tutorial/cors/
    * https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
    """

    enabled: bool = Field(
        # The trailing space keeps the two sentences from running together
        # ("...or not.If set...") in the rendered description.
        description="Flag indicating if CORS headers are set or not. "
        "If set to True, the CORS headers will be set to allow all origins, methods and headers.",
        default=False,
    )
    allow_credentials: bool = Field(
        description="Indicate that cookies should be supported for cross-origin requests",
        default=False,
    )
    allow_origins: list[str] = Field(
        description="A list of origins that should be permitted to make cross-origin requests.",
        default=[],
    )
    # A single pattern string (or None), as the description states. The former
    # `list[str]` annotation contradicted both the description and the `None`
    # default, and would accept a list for a value that is described (and
    # presumably consumed downstream) as one regex.
    allow_origin_regex: str | None = Field(
        description="A regex string to match against origins that should be permitted to make cross-origin requests.",
        default=None,
    )
    allow_methods: list[str] = Field(
        description="A list of HTTP methods that should be allowed for cross-origin requests.",
        default=[
            "GET",
        ],
    )
    allow_headers: list[str] = Field(
        description="A list of HTTP request headers that should be supported for cross-origin requests.",
        default=[],
    )
|
|
|
|
|
|
class AuthSettings(BaseModel):
    """Authentication configuration.

    Holds the switch that turns authentication on or off, together with the
    secret value clients are expected to present (see the ``secret`` field
    description for the HTTP basic authentication case).
    """

    # Authentication is opt-in.
    enabled: bool = Field(
        False,
        description="Flag indicating if authentication is enabled or not.",
    )
    # Required: there is no default, so a value must come from configuration
    # (or from whoever constructs this model).
    secret: str = Field(
        description=(
            "The secret to be used for authentication. "
            "It can be any non-blank string. "
            "For HTTP basic authentication, "
            "this value should be the whole 'Authorization' header that is expected"
        )
    )
|
|
|
|
|
|
class IngestionSettings(BaseModel):
    """Ingestion configuration.

    Controls ingestion of data into the system through non-server methods,
    which is handy for local development and testing, or for bulk ingestion
    from a folder.

    This configuration is not secure: only use it in a controlled environment
    (with proper permissions set, etc.).
    """

    # Local (non-server) ingestion is disabled unless explicitly turned on.
    enabled: bool = Field(
        False,
        description="Flag indicating if local ingestion is enabled or not.",
    )
    # Folders allowed as ingestion sources; empty by default.
    allow_ingest_from: list[str] = Field(
        [],
        description="A list of folders that should be permitted to make ingest requests.",
    )
|
|
|
|
|
|
class ServerSettings(BaseModel):
    """Settings for the PrivateGPT FastAPI server."""

    # Required: no default is provided.
    env_name: str = Field(
        description="Name of the environment (prod, staging, local...)"
    )
    # The description has always advertised 8001 as the default; declaring it
    # makes the model match its own documentation. Configurations that set the
    # port explicitly are unaffected.
    port: int = Field(
        8001,
        description="Port of PrivateGPT FastAPI server, defaults to 8001",
    )
    # Built per instance via default_factory, consistent with `auth` below,
    # instead of a single CorsSettings object created at class-definition time.
    cors: CorsSettings = Field(
        description="CORS configuration",
        default_factory=lambda: CorsSettings(enabled=False),
    )
    auth: AuthSettings = Field(
        description="Authentication configuration",
        default_factory=lambda: AuthSettings(enabled=False, secret="secret-key"),
    )
|
|
|
|
|
|
class DataSettings(BaseModel):
    """Settings for local data storage and local ingestion."""

    # The default factory allows ingestion from any folder ("*"); `enabled`
    # keeps IngestionSettings' own default.
    local_ingestion: IngestionSettings = Field(
        description="Ingestion configuration",
        default_factory=lambda: IngestionSettings(allow_ingest_from=["*"]),
    )
    # Required. The trailing space fixes the former "storage.It" run-on in the
    # rendered description.
    local_data_folder: str = Field(
        description="Path to local storage. "
        "It will be treated as an absolute path if it starts with /"
    )
|
|
|
|
|
|
class LLMSettings(BaseModel):
    """Settings for the language model used by the chat engine."""

    # Which LLM backend to use; required.
    mode: Literal[
        "llamacpp",
        "openai",
        "openailike",
        "azopenai",
        "sagemaker",
        "mock",
        "ollama",
        "gemini",
    ]
    max_new_tokens: int = Field(
        256,
        description="The maximum number of tokens that the LLM is authorized to generate in one completion.",
    )
    context_window: int = Field(
        3900,
        description="The maximum number of context tokens for the model.",
    )
    # Defaults to None (annotation kept as `str` for compatibility with
    # existing callers).
    tokenizer: str = Field(
        None,
        description="The model id of a predefined tokenizer hosted inside a model repo on "
        "huggingface.co. Valid model ids can be located at the root-level, like "
        "`bert-base-uncased`, or namespaced under a user or organization name, "
        "like `HuggingFaceH4/zephyr-7b-beta`. If not set, will load a tokenizer matching "
        "gpt-3.5-turbo LLM.",
    )
    temperature: float = Field(
        0.1,
        description="The temperature of the model. Increasing the temperature will make the model answer more creatively. A value of 0.1 would be more factual.",
    )
    # Each option gets its own newline-terminated line in the description.
    # This fixes the missing newline after `llama3`, the unclosed backtick and
    # typo in the `mistral` line, and the undocumented `chatml` option.
    prompt_style: Literal[
        "default", "llama2", "llama3", "tag", "mistral", "chatml"
    ] = Field(
        "llama2",
        description=(
            "The prompt style to use for the chat engine. "
            "If `default` - use the default prompt style from the llama_index. It should look like `role: message`.\n"
            "If `llama2` - use the llama2 prompt style from the llama_index. Based on `<s>`, `[INST]` and `<<SYS>>`.\n"
            "If `llama3` - use the llama3 prompt style from the llama_index.\n"
            "If `tag` - use the `tag` prompt style. It should look like `<|role|>: message`.\n"
            "If `mistral` - use the `mistral` prompt style. It should look like `<s>[INST] {System Prompt} [/INST]</s>[INST] { UserInstructions } [/INST]`.\n"
            "If `chatml` - use the `chatml` prompt style.\n"
            "`llama2` is the historic behaviour. `default` might work better with your custom models."
        ),
    )
|
|
|
|
|
|
class VectorstoreSettings(BaseModel):
    """Selects the vector store backend by name."""

    # Required; must be one of the listed backend identifiers.
    database: Literal[
        "chroma",
        "qdrant",
        "postgres",
        "clickhouse",
        "milvus",
    ]
|
|
|
|
|
|
class NodeStoreSettings(BaseModel):
    """Selects the node store backend by name."""

    # Required; either the "simple" backend or "postgres".
    database: Literal[
        "simple",
        "postgres",
    ]
|
|
|
|
|
|
class LlamaCPPSettings(BaseModel):
    """Settings for the llama.cpp backend: model location and sampling knobs."""

    # Both required: the HuggingFace repo id and the model file inside it.
    llm_hf_repo_id: str
    llm_hf_model_file: str

    # Sampling parameters, each with its default given as `default=`.
    tfs_z: float = Field(
        default=1.0,
        description="Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting.",
    )
    top_k: int = Field(
        default=40,
        description="Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)",
    )
    top_p: float = Field(
        default=0.9,
        description="Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)",
    )
    repeat_penalty: float = Field(
        default=1.1,
        description="Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)",
    )
|
|
|
|
|
|
class HuggingFaceSettings(BaseModel):
    """Settings for HuggingFace-hosted models."""

    # Required: no default.
    embedding_hf_model_name: str = Field(
        description="Name of the HuggingFace model to use for embeddings"
    )
    # Defaults to None when no token is configured.
    access_token: str = Field(
        default=None,
        description="Huggingface access token, required to download some models",
    )
    # Remote code execution is opt-in.
    trust_remote_code: bool = Field(
        default=False,
        description="If set to True, the code from the remote model will be trusted and executed.",
    )
|
|
|
|
|
|
class EmbeddingSettings(BaseModel):
    """Settings for the embedding model and the file-ingestion strategy."""

    # Which embedding backend to use; required.
    mode: Literal[
        "huggingface", "openai", "azopenai", "sagemaker", "ollama", "mock", "gemini"
    ]
    # Description lines follow the Literal order, use a consistent "If `x` -"
    # prefix, and fix the "embedd" / "parallelize" typos.
    ingest_mode: Literal["simple", "batch", "parallel", "pipeline"] = Field(
        "simple",
        description=(
            "The ingest mode to use for the embedding engine:\n"
            "If `simple` - ingest files sequentially and one by one. It is the historic behaviour.\n"
            "If `batch` - if multiple files, parse all the files in parallel, "
            "and send them in batch to the embedding model.\n"
            "If `parallel` - parse the files in parallel using multiple cores, and embed them in parallel.\n"
            "If `pipeline` - the embedding engine is kept as busy as possible.\n"
            "`parallel` is the fastest mode for local setup, as it parallelizes IO RW in the index.\n"
            "For modes that leverage parallelization, you can specify the number of "
            "workers to use with `count_workers`.\n"
        ),
    )
    count_workers: int = Field(
        2,
        description=(
            "The number of workers to use for file ingestion.\n"
            "In `batch` mode, this is the number of workers used to parse the files.\n"
            "In `parallel` mode, this is the number of workers used to parse the files and embed them.\n"
            "In `pipeline` mode, this is the number of workers that can perform embeddings.\n"
            "This is only used if `ingest_mode` is not `simple`.\n"
            "Do not go too high with this number, as it might cause memory issues. (especially in `parallel` mode)\n"
            "Do not set it higher than your number of threads of your CPU."
        ),
    )
    embed_dim: int = Field(
        384,
        description="The dimension of the embeddings stored in the Postgres database",
    )
|
|
|
|
|
|
class SagemakerSettings(BaseModel):
    """Names of the SageMaker endpoints to call (both required)."""

    llm_endpoint_name: str  # endpoint used for the LLM
    embedding_endpoint_name: str  # endpoint used for embeddings
|
|
|
|
|
|
class OpenAISettings(BaseModel):
    """Settings for OpenAI (and OpenAI-compatible) LLM and embedding access."""

    # Defaults to None when no custom base URL is configured.
    api_base: str = Field(
        None,
        description="Base URL of OpenAI API. Example: 'https://api.openai.com/v1'.",
    )
    api_key: str  # required
    model: str = Field(
        "gpt-3.5-turbo",
        description="OpenAI Model to use. Example: 'gpt-4'.",
    )
    request_timeout: float = Field(
        120.0,
        description="Time elapsed until openailike server times out the request. Default is 120s. Format is float. ",
    )
    # The description previously duplicated `api_base` verbatim; it now says
    # that this base URL is the one used for embeddings.
    embedding_api_base: str = Field(
        None,
        description="Base URL of OpenAI embedding API. Example: 'https://api.openai.com/v1'.",
    )
    embedding_api_key: str  # required
    embedding_model: str = Field(
        "text-embedding-ada-002",
        description="OpenAI embedding Model to use. Example: 'text-embedding-3-large'.",
    )
|
|
|
|
|
|
class GeminiSettings(BaseModel):
    """Settings for Google Gemini models."""

    api_key: str  # required

    # Model identifiers, defaulted via `default=`.
    model: str = Field(
        default="models/gemini-pro",
        description="Google Model to use. Example: 'models/gemini-pro'.",
    )
    embedding_model: str = Field(
        default="models/embedding-001",
        description="Google Embedding Model to use. Example: 'models/embedding-001'.",
    )
|
|
|
|
|
|
class OllamaSettings(BaseModel):
    """Settings for the Ollama backend: endpoints, models and sampling knobs."""

    # --- Endpoints -------------------------------------------------------
    api_base: str = Field(
        default="http://localhost:11434",
        description="Base URL of Ollama API. Example: 'https://localhost:11434'.",
    )
    embedding_api_base: str = Field(
        default="http://localhost:11434",
        description="Base URL of Ollama embedding API. Example: 'https://localhost:11434'.",
    )

    # --- Models (None unless configured) ----------------------------------
    llm_model: str = Field(
        default=None,
        description="Model to use. Example: 'llama2-uncensored'.",
    )
    embedding_model: str = Field(
        default=None,
        description="Model to use. Example: 'nomic-embed-text'.",
    )
    keep_alive: str = Field(
        default="5m",
        description="Time the model will stay loaded in memory after a request. examples: 5m, 5h, '-1' ",
    )

    # --- Sampling --------------------------------------------------------
    tfs_z: float = Field(
        default=1.0,
        description="Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting.",
    )
    num_predict: int = Field(
        default=None,
        description="Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context)",
    )
    top_k: int = Field(
        default=40,
        description="Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)",
    )
    top_p: float = Field(
        default=0.9,
        description="Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)",
    )
    repeat_last_n: int = Field(
        default=64,
        description="Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx)",
    )
    repeat_penalty: float = Field(
        default=1.1,
        description="Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)",
    )

    # --- Client behaviour ------------------------------------------------
    request_timeout: float = Field(
        default=120.0,
        description="Time elapsed until ollama times out the request. Default is 120s. Format is float. ",
    )
    autopull_models: bool = Field(
        default=False,
        description="If set to True, the Ollama will automatically pull the models from the API base.",
    )
|
|
|
|
|
|
class AzureOpenAISettings(BaseModel):
    """Settings for Azure OpenAI deployments (LLM and embeddings)."""

    api_key: str  # required
    azure_endpoint: str  # required
    # The description requires YYYY-MM-DD, but the former default
    # "2023_05_15" used underscores. The default now follows the documented
    # format.
    api_version: str = Field(
        "2023-05-15",
        description="The API version to use for this operation. This follows the YYYY-MM-DD format.",
    )
    embedding_deployment_name: str  # required
    embedding_model: str = Field(
        "text-embedding-ada-002",
        description="OpenAI embedding model to use. Example: 'text-embedding-ada-002'.",
    )
    llm_deployment_name: str  # required
    llm_model: str = Field(
        "gpt-35-turbo",
        description="OpenAI Model to use. Example: 'gpt-4'.",
    )
|
|
|
|
|
|
class UISettings(BaseModel):
    """Settings for the web UI: mounting, default prompts and button toggles."""

    # Both required.
    enabled: bool
    path: str

    # Default system prompts per mode; None unless configured.
    default_chat_system_prompt: str = Field(
        default=None,
        description="The default system prompt to use for the chat mode.",
    )
    default_query_system_prompt: str = Field(
        default=None,
        description="The default system prompt to use for the query mode.",
    )
    default_summarization_system_prompt: str = Field(
        default=None,
        description="The default system prompt to use for the summarization mode.",
    )

    # File-deletion buttons: single-file on by default, delete-all off.
    delete_file_button_enabled: bool = Field(
        default=True,
        description="If the button to delete a file is enabled or not.",
    )
    delete_all_files_button_enabled: bool = Field(
        default=False,
        description="If the button to delete all files is enabled or not.",
    )
|
|
|
|
|
|
class RerankSettings(BaseModel):
    """Settings for the optional reranking step of the RAG pipeline."""

    # Reranking is off unless enabled.
    enabled: bool = Field(
        default=False,
        description="This value controls whether a reranker should be included in the RAG pipeline.",
    )
    model: str = Field(
        default="cross-encoder/ms-marco-MiniLM-L-2-v2",
        description="Rerank model to use. Limited to SentenceTransformer cross-encoder models.",
    )
    top_n: int = Field(
        default=2,
        description="This value controls the number of documents returned by the RAG pipeline.",
    )
|
|
|
|
|
|
class RagSettings(BaseModel):
    """Retrieval settings for the RAG pipeline."""

    similarity_top_k: int = Field(
        default=2,
        description="This value controls the number of documents returned by the RAG pipeline or considered for reranking if enabled.",
    )
    # None means no minimum score is configured.
    similarity_value: float = Field(
        default=None,
        description="If set, any documents retrieved from the RAG must meet a certain match score. Acceptable values are between 0 and 1.",
    )
    rerank: RerankSettings  # required reranking sub-section
|
|
|
|
|
|
class SummarizeSettings(BaseModel):
    """Settings for the summarization feature."""

    # Asynchronous summarization is the default.
    use_async: bool = Field(
        default=True,
        description="If set to True, the summarization will be done asynchronously.",
    )
|
|
|
|
|
|
class ClickHouseSettings(BaseModel):
    """Connection settings for a ClickHouse database."""

    # --- Server & credentials --------------------------------------------
    host: str = Field(
        default="localhost",
        description="The server hosting the ClickHouse database",
    )
    port: int = Field(
        default=8443,
        description="The port on which the ClickHouse database is accessible",
    )
    username: str = Field(
        default="default",
        description="The username to use to connect to the ClickHouse database",
    )
    password: str = Field(
        default="",
        description="The password to use to connect to the ClickHouse database",
    )
    database: str = Field(
        default="__default__",
        description="The default database to use for connections",
    )

    # --- Transport -------------------------------------------------------
    secure: bool | str = Field(
        default=False,
        description="Use https/TLS for secure connection to the server",
    )
    interface: str | None = Field(
        default=None,
        description="Must be either 'http' or 'https'. Determines the protocol to use for the connection",
    )
    settings: dict[str, Any] | None = Field(
        default=None,
        description="Specific ClickHouse server settings to be used with the session",
    )
    connect_timeout: int | None = Field(
        default=None,
        description="Timeout in seconds for establishing a connection",
    )
    send_receive_timeout: int | None = Field(
        default=None,
        description="Read timeout in seconds for http connection",
    )

    # --- TLS -------------------------------------------------------------
    verify: bool | None = Field(
        default=None,
        description="Verify the server certificate in secure/https mode",
    )
    ca_cert: str | None = Field(
        default=None,
        description="Path to Certificate Authority root certificate (.pem format)",
    )
    client_cert: str | None = Field(
        default=None,
        description="Path to TLS Client certificate (.pem format)",
    )
    client_cert_key: str | None = Field(
        default=None,
        description="Path to the private key for the TLS Client certificate",
    )

    # --- Proxies & host verification --------------------------------------
    http_proxy: str | None = Field(
        default=None,
        description="HTTP proxy address",
    )
    https_proxy: str | None = Field(
        default=None,
        description="HTTPS proxy address",
    )
    server_host_name: str | None = Field(
        default=None,
        description="Server host name to be checked against the TLS certificate",
    )
|
|
|
|
|
|
class PostgresSettings(BaseModel):
    """Connection settings for a Postgres database."""

    # Server location.
    host: str = Field(
        default="localhost",
        description="The server hosting the Postgres database",
    )
    port: int = Field(
        default=5432,
        description="The port on which the Postgres database is accessible",
    )

    # Credentials.
    user: str = Field(
        default="postgres",
        description="The user to use to connect to the Postgres database",
    )
    password: str = Field(
        default="postgres",
        description="The password to use to connect to the Postgres database",
    )

    # Target database and schema.
    database: str = Field(
        default="postgres",
        description="The database to use to connect to the Postgres database",
    )
    schema_name: str = Field(
        default="public",
        description="The name of the schema in the Postgres database to use",
    )
|
|
|
|
|
|
class QdrantSettings(BaseModel):
    """Connection settings for a Qdrant vector store (remote or local)."""

    location: str | None = Field(
        None,
        description=(
            "If `:memory:` - use in-memory Qdrant instance.\n"
            "If `str` - use it as a `url` parameter.\n"
        ),
    )
    url: str | None = Field(
        None,
        description=(
            "Either host or str of 'Optional[scheme], host, Optional[port], Optional[prefix]'."
        ),
    )
    port: int | None = Field(6333, description="Port of the REST API interface.")
    grpc_port: int | None = Field(6334, description="Port of the gRPC interface.")
    prefer_grpc: bool | None = Field(
        False,
        description="If `true` - use gRPC interface whenever possible in custom methods.",
    )
    https: bool | None = Field(
        None,
        description="If `true` - use HTTPS(SSL) protocol.",
    )
    api_key: str | None = Field(
        None,
        description="API key for authentication in Qdrant Cloud.",
    )
    # Trailing space added so the rendered text no longer reads "path.Example".
    prefix: str | None = Field(
        None,
        description=(
            "Prefix to add to the REST URL path. "
            "Example: `service/v1` will result in "
            "'http://localhost:6333/service/v1/{qdrant-endpoint}' for REST API."
        ),
    )
    timeout: float | None = Field(
        None,
        description="Timeout for REST and gRPC API requests.",
    )
    host: str | None = Field(
        None,
        description="Host name of Qdrant service. If url and host are None, set to 'localhost'.",
    )
    path: str | None = Field(None, description="Persistence path for QdrantLocal.")
    # Sentence separator added so the rendered text no longer reads
    # "`True`Only use this...".
    force_disable_check_same_thread: bool | None = Field(
        True,
        description=(
            "For QdrantLocal, force disable check_same_thread. Default: `True`. "
            "Only use this if you can guarantee that you can resolve the thread safety outside QdrantClient."
        ),
    )
|
|
|
|
|
|
class MilvusSettings(BaseModel):
    """Connection and collection settings for a Milvus vector store."""

    # Defaults to a local Milvus Lite database file.
    uri: str = Field(
        default="local_data/private_gpt/milvus/milvus_local.db",
        description="The URI of the Milvus instance. For example: 'local_data/private_gpt/milvus/milvus_local.db' for Milvus Lite.",
    )
    # Empty string when no token is configured.
    token: str = Field(
        default="",
        description=(
            "A valid access token to access the specified Milvus instance. "
            "This can be used as a recommended alternative to setting user and password separately. "
        ),
    )
    collection_name: str = Field(
        default="make_this_parameterizable_per_api_call",
        description="The name of the collection in Milvus. Default is 'make_this_parameterizable_per_api_call'.",
    )
    overwrite: bool = Field(
        default=True,
        description="Overwrite the previous collection schema if it exists.",
    )
|
|
|
|
|
|
class Settings(BaseModel):
    """Root settings model with one sub-model per configuration section.

    Every section is required except the backend-specific ones at the end
    (qdrant, postgres, clickhouse, milvus), which default to None.
    """

    # Core application sections.
    server: ServerSettings
    data: DataSettings
    ui: UISettings

    # Model selection and provider-specific sections.
    llm: LLMSettings
    embedding: EmbeddingSettings
    llamacpp: LlamaCPPSettings
    huggingface: HuggingFaceSettings
    sagemaker: SagemakerSettings
    openai: OpenAISettings
    gemini: GeminiSettings
    ollama: OllamaSettings
    azopenai: AzureOpenAISettings

    # Storage backend selection and retrieval behaviour.
    vectorstore: VectorstoreSettings
    nodestore: NodeStoreSettings
    rag: RagSettings
    summarize: SummarizeSettings

    # Optional backend-specific connection settings.
    qdrant: QdrantSettings | None = None
    postgres: PostgresSettings | None = None
    clickhouse: ClickHouseSettings | None = None
    milvus: MilvusSettings | None = None
|
|
|
|
|
|
"""
|
|
This is visible just for DI or testing purposes.
|
|
|
|
Use dependency injection or `settings()` method instead.
|
|
"""
|
|
unsafe_settings = load_active_settings()
|
|
|
|
"""
|
|
This is visible just for DI or testing purposes.
|
|
|
|
Use dependency injection or `settings()` method instead.
|
|
"""
|
|
unsafe_typed_settings = Settings(**unsafe_settings)
|
|
|
|
|
|
def settings() -> Settings:
    """Return the `Settings` instance held by the global DI container.

    Kept for compatibility with existing code that needs global access to
    the settings. Regular components should receive `Settings` through
    dependency injection instead.
    """
    # Imported at call time rather than at module level, as before
    # (presumably to avoid an import cycle with the DI module — confirm).
    from private_gpt.di import global_injector

    current: Settings = global_injector.get(Settings)
    return current
|