"""Langchain Context-----------------This module contains functions for converting Chatsky's Message and Context to Langchain messages."""importreimportloggingfromtypingimportLiteral,Unionimportasynciofromchatsky.coreimportContext,Messagefromchatsky.llm._langchain_importsimportHumanMessage,SystemMessage,AIMessage,check_langchain_availablefromchatsky.llm.filtersimportBaseHistoryFilter,Returnfromchatsky.llm.promptimportPrompt,PositionConfiglogger=logging.getLogger(__name__)
async def message_to_langchain(
    message: Message, ctx: Context, source: Literal["human", "ai", "system"] = "human", max_size: int = 5000
) -> Union[HumanMessage, AIMessage, SystemMessage]:
    """
    Create a langchain message from a :py:class:`~chatsky.script.core.message.Message` object.

    :param message: Chatsky Message to convert to Langchain Message.
    :param ctx: Current dialog context.
    :param source: Source of the message [`human`, `ai`, `system`]. Defaults to "human".
    :param max_size: Maximum size of the message measured in characters.
        If a message exceeds the limit, it will not be sent to the LLM and a warning will be produced.
    """
    check_langchain_available()
    if message.text is None:
        content = []
    elif len(message.text) > max_size:
        logger.warning("Message is too long.")
        content = []
    else:
        content = [{"type": "text", "text": message.text}]

    if source == "human":
        return HumanMessage(content=content)
    elif source == "ai":
        return AIMessage(content=content)
    elif source == "system":
        return SystemMessage(content=content)
    else:
        return HumanMessage(content=content)
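
# Illustrative usage sketch (not part of the module): converting a single
# Chatsky message, assuming `ctx` is a Context obtained from a running pipeline.
# Note that oversized messages yield an empty `content` list, so they are
# dropped from the LLM input rather than truncated.
#
#     user_msg = Message(text="What is the weather like today?")
#     lc_msg = await message_to_langchain(user_msg, ctx=ctx, source="human")
#     # lc_msg == HumanMessage(content=[{"type": "text", "text": "What is the weather like today?"}])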
async def context_to_history(
    ctx: Context, length: int, filter_func: BaseHistoryFilter, llm_model_name: str, max_size: int
) -> list[Union[HumanMessage, AIMessage, SystemMessage]]:
    """
    Convert context to a list of langchain messages.

    :param ctx: Current dialog context.
    :param length: Number of turns to include in the history. Set to `-1` to include the whole context.
    :param filter_func: Function to filter the context.
    :param llm_model_name: Name of the model from the pipeline.
    :param max_size: Maximum size of the message in characters.
    :return: List of Langchain message objects.
    """
    check_langchain_available()
    history = []

    # Turn ids start at 1; the current (unfinished) turn is excluded.
    indices = list(range(1, ctx.current_turn_id))

    if length == 0:
        return []
    elif length > 0:
        indices = indices[-length:]

    # Fetch requests and responses concurrently and pair them up by turn.
    for request, response in zip(*await asyncio.gather(ctx.requests.get(indices), ctx.responses.get(indices))):
        filter_result = filter_func(ctx, request, response, llm_model_name)
        if request is not None and filter_result in (Return.Request, Return.Turn):
            history.append(await message_to_langchain(request, ctx=ctx, max_size=max_size))
        if response is not None and filter_result in (Return.Response, Return.Turn):
            history.append(await message_to_langchain(response, ctx=ctx, source="ai", max_size=max_size))

    return history
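
# Illustrative usage sketch: pulling the last three turns of history. `ctx` and
# `my_filter` are assumed to exist already; `my_filter` stands for any
# BaseHistoryFilter instance, which this function invokes as
# `filter_func(ctx, request, response, llm_model_name)`. The model name
# "my_model" is a hypothetical pipeline model key.
#
#     history = await context_to_history(
#         ctx,
#         length=3,                   # last three turns; -1 includes everything
#         filter_func=my_filter,
#         llm_model_name="my_model",
#         max_size=5000,
#     )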
async def get_langchain_context(
    system_prompt: Message,
    ctx: Context,
    call_prompt: Prompt,
    prompt_misc_filter: str = r"prompt",  # misc keys matching this regex are treated as prompts
    position_config: PositionConfig = PositionConfig(),
    **history_args,
) -> list[Union[HumanMessage, AIMessage, SystemMessage]]:
    """
    Get a list of Langchain messages using the context and prompts.

    :param system_prompt: System message to be included in the context.
    :param ctx: Current dialog context.
    :param call_prompt: Prompt to be used for the current call.
    :param prompt_misc_filter: Regex pattern to filter miscellaneous prompts from the context.
        Defaults to r"prompt".
    :param position_config: Configuration for positioning the different parts of the context.
        Defaults to a default-constructed PositionConfig.
    :param history_args: Additional arguments passed to the context_to_history function.
    :return: List of Langchain message objects ordered by their position values.
    """
    check_langchain_available()
    logger.debug(f"History args: {history_args}")
    history = await context_to_history(ctx, **history_args)
    logger.debug(f"Position config: {position_config}")

    # Each entry pairs a block of messages with the position value used for ordering.
    prompts: list[tuple[list[Union[HumanMessage, AIMessage, SystemMessage]], float]] = []
    if system_prompt.text != "":
        system_prompt_block = [await message_to_langchain(system_prompt, ctx, source="system")]
        logger.debug(f"System prompt: {system_prompt_block}")
        prompts.append((system_prompt_block, position_config.system_prompt))
    prompts.append((history, position_config.history))

    misc_pattern = re.compile(prompt_misc_filter)
    for element_name, element in ctx.current_node.misc.items():
        if misc_pattern.match(element_name):
            prompt = Prompt.model_validate(element)
            prompt_langchain_message = await message_to_langchain(await prompt.message(ctx), ctx, source="human")
            prompts.append(
                (
                    [prompt_langchain_message],
                    prompt.position if prompt.position is not None else position_config.misc_prompt,
                )
            )

    call_prompt_text = await call_prompt.message(ctx)
    if call_prompt_text.text != "":
        call_prompt_message = await message_to_langchain(call_prompt_text, ctx, source="human")
        prompts.append(
            (
                [call_prompt_message],
                call_prompt.position if call_prompt.position is not None else position_config.call_prompt,
            )
        )

    last_turn_request = await ctx.requests.get(ctx.current_turn_id)
    last_turn_response = await ctx.responses.get(ctx.current_turn_id)
    if last_turn_request:
        prompts.append(
            ([await message_to_langchain(last_turn_request, ctx, source="human")], position_config.last_turn)
        )
    if last_turn_response:
        prompts.append(
            ([await message_to_langchain(last_turn_response, ctx, source="ai")], position_config.last_turn)
        )

    logger.debug(f"Prompts: {prompts}")
    prompts = sorted(prompts, key=lambda x: x[1])

    # Flatten the ordered blocks into a single message list.
    langchain_context = []
    for message_block in prompts:
        langchain_context.extend(message_block[0])
    return langchain_context
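
# Illustrative usage sketch: assembling the full LLM input for one call.
# `ctx` and `my_filter` are assumed to exist; "my_model" is a hypothetical
# model key, and passing a plain Message as Prompt's `message` field is an
# assumption based on the Prompt.model_validate usage above. The keyword
# arguments that are not named parameters of this function are forwarded to
# context_to_history via **history_args.
#
#     langchain_context = await get_langchain_context(
#         system_prompt=Message(text="You are a helpful assistant."),
#         ctx=ctx,
#         call_prompt=Prompt(message=Message(text="Answer the user's question.")),
#         length=-1,
#         filter_func=my_filter,
#         llm_model_name="my_model",
#         max_size=5000,
#     )
#     # The resulting message blocks are sorted by their position values
#     # from position_config before being flattened into one list.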