Source code for chatsky.llm.langchain_context

"""
Langchain Context
-----------------
This module contains functions for converting Chatsky's Message and Context to Langchain messages.
"""

import re
import logging
from typing import Literal, Union
import asyncio

from chatsky.core import Context, Message
from chatsky.llm._langchain_imports import HumanMessage, SystemMessage, AIMessage, check_langchain_available
from chatsky.llm.filters import BaseHistoryFilter, Return
from chatsky.llm.prompt import Prompt, PositionConfig


# Module-level logger: used below to warn about oversized messages and to emit
# debug traces while the Langchain context is being assembled.
logger = logging.getLogger(__name__)


async def message_to_langchain(
    message: Message, ctx: Context, source: Literal["human", "ai", "system"] = "human", max_size: int = 5000
) -> Union[HumanMessage, AIMessage, SystemMessage]:
    """
    Create a langchain message from a :py:class:`~chatsky.core.message.Message` object.

    Only ``message.text`` is converted; no other field of the Message is read.

    :param message: Chatsky Message to convert to Langchain Message.
    :param ctx: Current dialog context. Not used by the conversion itself.
    :param source: Source of the message [`human`, `ai`, `system`]. Defaults to "human".
        Any other value falls back to a human message and logs a warning.
    :param max_size: Maximum size of the message measured in characters.
        If the text exceeds the limit, its text is not sent to the LLM:
        a warning is logged and the returned message has empty content.
        A message whose text is ``None`` also yields empty content.
    :return: A ``HumanMessage``, ``AIMessage`` or ``SystemMessage`` depending on ``source``.
    """
    check_langchain_available()

    text = message.text
    if text is None:
        content = []
    elif len(text) > max_size:
        logger.warning("Message is too long.")
        content = []
    else:
        content = [{"type": "text", "text": text}]

    if source == "ai":
        return AIMessage(content=content)
    if source == "system":
        return SystemMessage(content=content)
    if source != "human":
        # Previously this fallback was silent; keep it (callers may rely on it) but make it visible.
        logger.warning("Unknown message source %r, treating it as 'human'.", source)
    return HumanMessage(content=content)
async def context_to_history(
    ctx: Context, length: int, filter_func: BaseHistoryFilter, llm_model_name: str, max_size: int
) -> list[Union[HumanMessage, AIMessage, SystemMessage]]:
    """
    Build a list of Langchain messages from the previous turns of the dialog.

    Turns ``1`` through ``ctx.current_turn_id - 1`` are considered; the current turn is excluded.
    For every turn, ``filter_func`` decides whether its request, its response, both or
    neither are kept. Kept messages are appended turn by turn, request before response.

    :param ctx: Current dialog context.
    :param length: Number of most recent turns to include. ``0`` returns an empty list;
        a negative value (e.g. ``-1``) includes every previous turn.
    :param filter_func: Function to filter the context; called as
        ``filter_func(ctx, request, response, llm_model_name)``.
    :param llm_model_name: Name of the model from the pipeline, forwarded to ``filter_func``.
    :param max_size: Maximum size of a single message in characters, forwarded to
        :py:func:`message_to_langchain`.
    :return: List of Langchain message objects; requests become human messages and
        responses become AI messages.
    """
    check_langchain_available()

    turn_ids = list(range(1, ctx.current_turn_id))
    if length == 0:
        return []
    if length > 0:
        turn_ids = turn_ids[-length:]

    # Fetch requests and responses for the selected turns concurrently.
    requests, responses = await asyncio.gather(ctx.requests.get(turn_ids), ctx.responses.get(turn_ids))

    keep_request = (Return.Request, Return.Turn)
    keep_response = (Return.Response, Return.Turn)

    messages = []
    for req, resp in zip(requests, responses):
        verdict = filter_func(ctx, req, resp, llm_model_name)
        if req is not None and verdict in keep_request:
            messages.append(await message_to_langchain(req, ctx=ctx, max_size=max_size))
        if resp is not None and verdict in keep_response:
            messages.append(await message_to_langchain(resp, ctx=ctx, source="ai", max_size=max_size))
    return messages
async def get_langchain_context(
    system_prompt: Message,
    ctx: Context,
    call_prompt: Prompt,
    prompt_misc_filter: str = r"prompt",  # r"prompt" -> extract misc prompts
    position_config: PositionConfig = PositionConfig(),
    **history_args,
) -> list[Union[HumanMessage, AIMessage, SystemMessage]]:
    """
    Get a list of Langchain messages using the context and prompts.

    The result is assembled from these parts, each tagged with a numeric position:

    * the system prompt (skipped when its text is empty or ``None``),
    * the dialog history produced by :py:func:`context_to_history`,
    * every prompt in ``ctx.current_node.misc`` whose key matches ``prompt_misc_filter``
      (matched with ``re.match``, i.e. anchored at the start of the key),
    * the call prompt (skipped when its text is empty or ``None``),
    * the request and the response of the current turn, when present.

    The parts are sorted by position with a stable sort, so parts sharing a position keep
    the order listed above, and then flattened into a single list.

    :param system_prompt: System message to be included in the context.
    :param ctx: Current dialog context.
    :param call_prompt: Prompt to be used for the current call.
    :param prompt_misc_filter: Regex pattern to filter miscellaneous prompts from context.
        Defaults to r"prompt".
    :param position_config: Configuration for positioning different parts of the context.
        Defaults to default PositionConfig().
    :param history_args: Additional arguments to be passed to context_to_history function.
    :return: List of Langchain message objects ordered by their position values.
    """
    check_langchain_available()
    # Lazy %-style arguments: the (potentially large) values are only formatted
    # when debug logging is actually enabled.
    logger.debug("History args: %s", history_args)
    history = await context_to_history(ctx, **history_args)
    logger.debug("Position config: %s", position_config)

    # Each entry is (messages, position); entries are sorted by position at the end.
    prompts: list[tuple[list[Union[HumanMessage, AIMessage, SystemMessage]], float]] = []

    # `Message.text` may be None as well as "" (message_to_langchain handles None);
    # skip both rather than sending an empty system message.
    if system_prompt.text:
        system_message = await message_to_langchain(system_prompt, ctx, source="system")
        prompts.append(([system_message], position_config.system_prompt))
        logger.debug("System prompt: %s", system_message)
    prompts.append((history, position_config.history))

    misc_pattern = re.compile(prompt_misc_filter)
    for element_name, element in ctx.current_node.misc.items():
        if not misc_pattern.match(element_name):
            continue
        prompt = Prompt.model_validate(element)
        prompt_message = await message_to_langchain(await prompt.message(ctx), ctx, source="human")
        prompt_position = prompt.position if prompt.position is not None else position_config.misc_prompt
        prompts.append(([prompt_message], prompt_position))

    call_prompt_msg = await call_prompt.message(ctx)
    # Same None-or-empty check as for the system prompt.
    if call_prompt_msg.text:
        call_prompt_message = await message_to_langchain(call_prompt_msg, ctx, source="human")
        call_position = call_prompt.position if call_prompt.position is not None else position_config.call_prompt
        prompts.append(([call_prompt_message], call_position))

    # Fetch the current turn's request and response concurrently (as context_to_history does).
    last_turn_request, last_turn_response = await asyncio.gather(
        ctx.requests.get(ctx.current_turn_id), ctx.responses.get(ctx.current_turn_id)
    )
    if last_turn_request:
        prompts.append(
            ([await message_to_langchain(last_turn_request, ctx, source="human")], position_config.last_turn)
        )
    if last_turn_response:
        prompts.append(([await message_to_langchain(last_turn_response, ctx, source="ai")], position_config.last_turn))

    logger.debug("Prompts: %s", prompts)
    # list.sort is stable: entries with equal positions keep their insertion order
    # (e.g. the last request stays before the last response).
    prompts.sort(key=lambda entry: entry[1])

    # Flatten the blocks into a single list of messages.
    return [message for messages, _ in prompts for message in messages]