Source code for dialog2graph.pipelines.helpers.find_cycle_ends
"""
Find cycle ends
----------------
This module provides an auxiliary graph helper for finding cycle ends.
"""
import json
from typing import Any

from langchain.prompts import PromptTemplate
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field

from dialog2graph.pipelines.core.graph import Graph
[docs]
def find_cycle_ends(G: Graph, cycle_ends_model: BaseChatModel) -> dict[str, Any]:
    """
    Find cycle-end nodes in dialog graph G with the help of an LLM.

    A node A is asked for when it has only one outgoing edge (an edge with
    source=A) and the target of that edge is located earlier in the dialog
    flow. The selection itself is made by ``cycle_ends_model``. This function
    only serializes the graph, builds the prompt and parses the reply.

    Parameters:
        G (Graph): The dialog graph; ``G.graph_dict`` is sent to the model
            as a JSON string.
        cycle_ends_model (BaseChatModel): The LLM to be used.

    Returns:
        dict: ``{'value': list, 'description': str}``. ``value`` holds the
        IDs of the nodes the model selected. ``description`` is the model's
        explanation of its decision.

    Errors raised while serializing ``G.graph_dict``, invoking the model or
    parsing its reply are not caught here and propagate to the caller.
    """

    class GraphEndsResult(BaseModel):
        """Schema the model's JSON reply is parsed into."""

        # IDs of the nodes the model reports as cycle ends.
        ends: list = Field(description="IDs of ending nodes")
        # The model's free-text justification.
        description: str = Field(description="Explanation of model's decision")

    # The braces around the JSON example are doubled so the template formatter
    # emits them literally; {json_graph} is the only input variable.
    graph_ends_prompt_template = """
Your task is to find IDs of all the nodes satisfying condition below:
Let's consider node with id A.
There is only one edge with source=A in the whole graph, and target of this edge is located earlier in the dialog flow.
Given this conversation graph in JSON:
{json_graph}
Reply in JSON format:
{{"ends": [id1, id2, ...], "description": "Brief explanation of your decision"}}
"""
    graph_ends_prompt = PromptTemplate(
        input_variables=["json_graph"], template=graph_ends_prompt_template
    )
    parser = PydanticOutputParser(pydantic_object=GraphEndsResult)

    # Pipeline: fill prompt -> query chat model -> parse into GraphEndsResult.
    find_ends_chain = graph_ends_prompt | cycle_ends_model | parser
    response = find_ends_chain.invoke({"json_graph": json.dumps(G.graph_dict)})

    return {"value": response.ends, "description": response.description}