"""
LLM Utils.
----------
The Utils module contains functions for converting Chatsky objects into LLM_API- and Langchain-compatible versions.
"""
import re
import logging
import asyncio
from typing import Literal, Union

from chatsky.core import Context, Message
from chatsky.llm._langchain_imports import HumanMessage, SystemMessage, AIMessage, check_langchain_available
from chatsky.llm.filters import BaseHistoryFilter, Return
from chatsky.llm.prompt import Prompt, PositionConfig

logger = logging.getLogger(__name__)


async def message_to_langchain(
    message: Message, ctx: Context, source: Literal["human", "ai", "system"] = "human", max_size: int = 5000
) -> Union[HumanMessage, AIMessage, SystemMessage]:
    """
    Create a langchain message from a :py:class:`~chatsky.core.message.Message` object.

    :param message: Chatsky Message to convert to a Langchain message.
    :param ctx: Current dialog context.
    :param source: Source of the message [`human`, `ai`, `system`]. Defaults to "human".
    :param max_size: Maximum size of the message measured in characters.
        If the message exceeds this limit, its text is dropped (an empty content
        list is sent instead) and a warning is logged.
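
    Example (a minimal sketch; ``ctx`` is assumed to be the current dialog context
    and the message text is illustrative):

    .. code-block:: python

        langchain_msg = await message_to_langchain(
            Message(text="Hello!"), ctx=ctx, source="human"
        )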
"""
    check_langchain_available()
    if message.text is None:
        content = []
    elif len(message.text) > max_size:
        logger.warning("Message is too long, dropping its text.")
        content = []
    else:
        content = [{"type": "text", "text": message.text}]

    if source == "human":
        return HumanMessage(content=content)
    elif source == "ai":
        return AIMessage(content=content)
    elif source == "system":
        return SystemMessage(content=content)
    else:
        raise ValueError(f"Invalid message source: {source!r}")


async def context_to_history(
    ctx: Context, length: int, filter_func: BaseHistoryFilter, llm_model_name: str, max_size: int
) -> list[Union[HumanMessage, AIMessage, SystemMessage]]:
    """
    Convert context to a list of langchain messages.

    :param ctx: Current dialog context.
    :param length: Number of turns to include in the history. Set to `-1` to include the whole context.
    :param filter_func: Function used to filter the context.
    :param llm_model_name: Name of the model from the pipeline.
    :param max_size: Maximum size of a message in characters.
    :return: List of Langchain message objects.
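
    Example (a sketch; ``my_filter`` stands for some
    :py:class:`~chatsky.llm.filters.BaseHistoryFilter` instance and the remaining
    values are illustrative):

    .. code-block:: python

        history = await context_to_history(
            ctx, length=5, filter_func=my_filter, llm_model_name="bot_model", max_size=1000
        )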
"""
    check_langchain_available()
    history = []

    if length == 0:
        return []
    # ids of all turns before the current one
    indices = list(range(1, ctx.current_turn_id))
    if length > 0:
        # keep only the last `length` turns
        indices = indices[-length:]

    for request, response in zip(*await asyncio.gather(ctx.requests.get(indices), ctx.responses.get(indices))):
        filter_result = filter_func(ctx, request, response, llm_model_name)
        if request is not None and filter_result in (Return.Request, Return.Turn):
            history.append(await message_to_langchain(request, ctx=ctx, max_size=max_size))
        if response is not None and filter_result in (Return.Response, Return.Turn):
            history.append(await message_to_langchain(response, ctx=ctx, source="ai", max_size=max_size))
    return history


async def get_langchain_context(
    system_prompt: Message,
    ctx: Context,
    call_prompt: Prompt,
    prompt_misc_filter: str = r"prompt",  # r"prompt" -> extract misc prompts
    position_config: PositionConfig = PositionConfig(),
    **history_args,
) -> list[Union[HumanMessage, AIMessage, SystemMessage]]:
    """
    Get a list of Langchain messages using the context and prompts.

    :param system_prompt: System message to be included in the context.
    :param ctx: Current dialog context.
    :param call_prompt: Prompt to be used for the current call.
    :param prompt_misc_filter: Regex pattern for selecting prompts from the current node's misc
        (e.g. keys such as ``prompt_1`` or ``prompt_greeting`` match the default pattern).
        Defaults to r"prompt".
    :param position_config: Configuration for positioning different parts of the context.
        Defaults to ``PositionConfig()``.
    :param history_args: Additional arguments passed to the :py:func:`context_to_history` function.
    :return: List of Langchain message objects ordered by their position values.
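
    Example (a sketch; ``my_filter`` is an assumed
    :py:class:`~chatsky.llm.filters.BaseHistoryFilter` instance, the prompt texts are
    illustrative, and the keyword arguments after ``call_prompt`` are forwarded to
    :py:func:`context_to_history`):

    .. code-block:: python

        messages = await get_langchain_context(
            system_prompt=Message(text="You are a helpful assistant."),
            ctx=ctx,
            call_prompt=Prompt(message=Message(text="Answer the user's question.")),
            length=3,
            filter_func=my_filter,
            llm_model_name="bot_model",
            max_size=1000,
        )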
"""
    check_langchain_available()
    logger.debug(f"History args: {history_args}")
    history = await context_to_history(ctx, **history_args)
    logger.debug(f"Position config: {position_config}")

    # each entry pairs a block of messages with its position in the final context
    prompts: list[tuple[list[Union[HumanMessage, AIMessage, SystemMessage]], float]] = []
    if system_prompt.text != "":
        system_prompt_block = (
            [await message_to_langchain(system_prompt, ctx, source="system")],
            position_config.system_prompt,
        )
        logger.debug(f"System prompt: {system_prompt_block}")
        prompts.append(system_prompt_block)
    prompts.append((history, position_config.history))

    # extract prompts from the current node's misc that match the filter pattern
    for element_name, element in ctx.current_node.misc.items():
        if re.compile(prompt_misc_filter).match(element_name):
            prompt = Prompt.model_validate(element)
            prompt_langchain_message = await message_to_langchain(await prompt.message(ctx), ctx, source="human")
            prompts.append(
                (
                    [prompt_langchain_message],
                    prompt.position if prompt.position is not None else position_config.misc_prompt,
                )
            )

    call_prompt_text = await call_prompt.message(ctx)
    if call_prompt_text.text != "":
        call_prompt_message = await message_to_langchain(call_prompt_text, ctx, source="human")
        prompts.append(
            (
                [call_prompt_message],
                call_prompt.position if call_prompt.position is not None else position_config.call_prompt,
            )
        )

    # the current turn's request and response are positioned separately from the history
    last_turn_request = await ctx.requests.get(ctx.current_turn_id)
    last_turn_response = await ctx.responses.get(ctx.current_turn_id)
    if last_turn_request:
        prompts.append(
            ([await message_to_langchain(last_turn_request, ctx, source="human")], position_config.last_turn)
        )
    if last_turn_response:
        prompts.append(([await message_to_langchain(last_turn_response, ctx, source="ai")], position_config.last_turn))
    logger.debug(f"Prompts: {prompts}")

    # order blocks by position, then flatten into a single message list
    prompts = sorted(prompts, key=lambda x: x[1])
    langchain_context = []
    for message_block in prompts:
        langchain_context.extend(message_block[0])
    return langchain_context