# mypy: ignore-errors import json from typing import Any import boto3 from llama_index.core.base.embeddings.base import BaseEmbedding from pydantic import Field, PrivateAttr class SagemakerEmbedding(BaseEmbedding): """Sagemaker Embedding Endpoint. To use, you must supply the endpoint name from your deployed Sagemaker embedding model & the region where it is deployed. To authenticate, the AWS client uses the following methods to automatically load credentials: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html If a specific credential profile should be used, you must pass the name of the profile from the ~/.aws/credentials file that is to be used. Make sure the credentials / roles used have the required policies to access the Sagemaker endpoint. See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html """ endpoint_name: str = Field(description="") _boto_client: Any = boto3.client( "sagemaker-runtime", ) # TODO make it an optional field _async_not_implemented_warned: bool = PrivateAttr(default=False) @classmethod def class_name(cls) -> str: return "SagemakerEmbedding" def _async_not_implemented_warn_once(self) -> None: if not self._async_not_implemented_warned: print("Async embedding not available, falling back to sync method.") self._async_not_implemented_warned = True def _embed(self, sentences: list[str]) -> list[list[float]]: request_params = { "inputs": sentences, } resp = self._boto_client.invoke_endpoint( EndpointName=self.endpoint_name, Body=json.dumps(request_params), ContentType="application/json", ) response_body = resp["Body"] response_str = response_body.read().decode("utf-8") response_json = json.loads(response_str) return response_json["vectors"] def _get_query_embedding(self, query: str) -> list[float]: """Get query embedding.""" return self._embed([query])[0] async def _aget_query_embedding(self, query: str) -> list[float]: # Warn the user that sync is being used self._async_not_implemented_warn_once() return self._get_query_embedding(query) async def _aget_text_embedding(self, text: str) -> list[float]: # Warn the user that sync is being used self._async_not_implemented_warn_once() return self._get_text_embedding(text) def _get_text_embedding(self, text: str) -> list[float]: """Get text embedding.""" return self._embed([text])[0] def _get_text_embeddings(self, texts: list[str]) -> list[list[float]]: """Get text embeddings.""" return self._embed(texts)