# mypy: ignore-errors from __future__ import annotations import io import json import logging from typing import TYPE_CHECKING, Any import boto3 # type: ignore from llama_index.core.base.llms.generic_utils import ( completion_response_to_chat_response, stream_completion_response_to_chat_response, ) from llama_index.core.bridge.pydantic import Field from llama_index.core.llms import ( CompletionResponse, CustomLLM, LLMMetadata, ) from llama_index.core.llms.callbacks import ( llm_chat_callback, llm_completion_callback, ) if TYPE_CHECKING: from collections.abc import Sequence from llama_index.callbacks import CallbackManager from llama_index.llms import ( ChatMessage, ChatResponse, ChatResponseGen, CompletionResponseGen, ) logger = logging.getLogger(__name__) class LineIterator: r"""A helper class for parsing the byte stream input from TGI container. The output of the model will be in the following format: ``` b'data:{"token": {"text": " a"}}\n\n' b'data:{"token": {"text": " challenging"}}\n\n' b'data:{"token": {"text": " problem" b'}}' ... ``` While usually each PayloadPart event from the event stream will contain a byte array with a full json, this is not guaranteed and some of the json objects may be split across PayloadPart events. For example: ``` {'PayloadPart': {'Bytes': b'{"outputs": '}} {'PayloadPart': {'Bytes': b'[" problem"]}\n'}} ``` This class accounts for this by concatenating bytes written via the 'write' function and then exposing a method which will return lines (ending with a '\n' character) within the buffer via the 'scan_lines' function. It maintains the position of the last read position to ensure that previous bytes are not exposed again. It will also save any pending lines that doe not end with a '\n' to make sure truncations are concatinated """ def __init__(self, stream: Any) -> None: """Line iterator initializer.""" self.byte_iterator = iter(stream) self.buffer = io.BytesIO() self.read_pos = 0 def __iter__(self) -> Any: """Self iterator.""" return self def __next__(self) -> Any: """Next element from iterator.""" while True: self.buffer.seek(self.read_pos) line = self.buffer.readline() if line and line[-1] == ord("\n"): self.read_pos += len(line) return line[:-1] try: chunk = next(self.byte_iterator) except StopIteration: if self.read_pos < self.buffer.getbuffer().nbytes: continue raise if "PayloadPart" not in chunk: logger.warning("Unknown event type=%s", chunk) continue self.buffer.seek(0, io.SEEK_END) self.buffer.write(chunk["PayloadPart"]["Bytes"]) class SagemakerLLM(CustomLLM): """Sagemaker Inference Endpoint models. To use, you must supply the endpoint name from your deployed Sagemaker model & the region where it is deployed. To authenticate, the AWS client uses the following methods to automatically load credentials: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html If a specific credential profile should be used, you must pass the name of the profile from the ~/.aws/credentials file that is to be used. Make sure the credentials / roles used have the required policies to access the Sagemaker endpoint. See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html """ endpoint_name: str = Field(description="") temperature: float = Field(description="The temperature to use for sampling.") max_new_tokens: int = Field(description="The maximum number of tokens to generate.") context_window: int = Field( description="The maximum number of context tokens for the model." ) messages_to_prompt: Any = Field( description="The function to convert messages to a prompt.", exclude=True ) completion_to_prompt: Any = Field( description="The function to convert a completion to a prompt.", exclude=True ) generate_kwargs: dict[str, Any] = Field( default_factory=dict, description="Kwargs used for generation." ) model_kwargs: dict[str, Any] = Field( default_factory=dict, description="Kwargs used for model initialization." ) verbose: bool = Field(description="Whether to print verbose output.") _boto_client: Any = boto3.client( "sagemaker-runtime", ) # TODO make it an optional field def __init__( self, endpoint_name: str | None = "", temperature: float = 0.1, max_new_tokens: int = 512, # to review defaults context_window: int = 2048, # to review defaults messages_to_prompt: Any = None, completion_to_prompt: Any = None, callback_manager: CallbackManager | None = None, generate_kwargs: dict[str, Any] | None = None, model_kwargs: dict[str, Any] | None = None, verbose: bool = True, ) -> None: """SagemakerLLM initializer.""" model_kwargs = model_kwargs or {} model_kwargs.update({"n_ctx": context_window, "verbose": verbose}) messages_to_prompt = messages_to_prompt or {} completion_to_prompt = completion_to_prompt or {} generate_kwargs = generate_kwargs or {} generate_kwargs.update( {"temperature": temperature, "max_tokens": max_new_tokens} ) super().__init__( endpoint_name=endpoint_name, temperature=temperature, context_window=context_window, max_new_tokens=max_new_tokens, messages_to_prompt=messages_to_prompt, completion_to_prompt=completion_to_prompt, callback_manager=callback_manager, generate_kwargs=generate_kwargs, model_kwargs=model_kwargs, verbose=verbose, ) @property def inference_params(self): # TODO expose the rest of params return { "do_sample": True, "top_p": 0.7, "temperature": self.temperature, "top_k": 50, "max_new_tokens": self.max_new_tokens, } @property def metadata(self) -> LLMMetadata: """Get LLM metadata.""" return LLMMetadata( context_window=self.context_window, num_output=self.max_new_tokens, model_name="Sagemaker LLama 2", ) @llm_completion_callback() def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: self.generate_kwargs.update({"stream": False}) is_formatted = kwargs.pop("formatted", False) if not is_formatted: prompt = self.completion_to_prompt(prompt) request_params = { "inputs": prompt, "stream": False, "parameters": self.inference_params, } resp = self._boto_client.invoke_endpoint( EndpointName=self.endpoint_name, Body=json.dumps(request_params), ContentType="application/json", ) response_body = resp["Body"] response_str = response_body.read().decode("utf-8") response_dict = json.loads(response_str) return CompletionResponse( text=response_dict[0]["generated_text"][len(prompt) :], raw=resp ) @llm_completion_callback() def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen: def get_stream(): text = "" request_params = { "inputs": prompt, "stream": True, "parameters": self.inference_params, } resp = self._boto_client.invoke_endpoint_with_response_stream( EndpointName=self.endpoint_name, Body=json.dumps(request_params), ContentType="application/json", ) event_stream = resp["Body"] start_json = b"{" stop_token = "<|endoftext|>" first_token = True for line in LineIterator(event_stream): if line != b"" and start_json in line: data = json.loads(line[line.find(start_json) :].decode("utf-8")) special = data["token"]["special"] stop = data["token"]["text"] == stop_token if not special and not stop: delta = data["token"]["text"] # trim the leading space for the first token if present if first_token: delta = delta.lstrip() first_token = False text += delta yield CompletionResponse(delta=delta, text=text, raw=data) return get_stream() @llm_chat_callback() def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: prompt = self.messages_to_prompt(messages) completion_response = self.complete(prompt, formatted=True, **kwargs) return completion_response_to_chat_response(completion_response) @llm_chat_callback() def stream_chat( self, messages: Sequence[ChatMessage], **kwargs: Any ) -> ChatResponseGen: prompt = self.messages_to_prompt(messages) completion_response = self.stream_complete(prompt, formatted=True, **kwargs) return stream_completion_response_to_chat_response(completion_response)