"""
Slots
-----
This module defines base classes for slots and some concrete implementations of them.
"""
from __future__ import annotations
import asyncio
import re
from abc import ABC, abstractmethod
from typing import Callable, Any, Awaitable, TYPE_CHECKING, Union, Optional, Dict
from typing_extensions import TypeAlias, Annotated
import logging
from functools import reduce
from string import Formatter
from pydantic import BaseModel, model_validator, Field, field_serializer, field_validator
from chatsky.utils.devel.async_helpers import wrap_sync_function_in_async
from chatsky.utils.devel.json_serialization import pickle_serializer, pickle_validator
if TYPE_CHECKING:
from chatsky.core import Context, Message
# Module-level logger; used below to report extraction failures and template-filling errors.
logger = logging.getLogger(__name__)
# Plain ``str`` alias used to annotate (possibly dot-separated) slot names.
SlotName: TypeAlias = str
"""
A string to identify slots.

Top-level slots are identified by their key in a :py:class:`~.GroupSlot`.

E.g.

.. code:: python

    GroupSlot(
        user=RegexpSlot(),
        password=FunctionSlot,
    )

Has two slots with names "user" and "password".

For nested group slots use dots to separate names:

.. code:: python

    GroupSlot(
        user=GroupSlot(
            name=FunctionSlot,
            password=FunctionSlot,
        )
    )

Has two slots with names "user.name" and "user.password".
"""
[docs]
def recursive_getattr(obj, slot_name: SlotName):
    """
    Resolve a dot-separated ``slot_name`` against ``obj``, one attribute at a time.

    :param obj: Object to start the lookup from.
    :param slot_name: Attribute path such as ``"user.name"``.
    :return: The object found at the end of the path. ``None`` is substituted
        for every attribute along the way that does not exist.
    """
    current = obj
    for attribute in slot_name.split("."):
        # Default to None rather than relying on how pydantic reports
        # access to an extra field that does not exist.
        current = getattr(current, attribute, None)
    return current
[docs]
def recursive_setattr(obj, slot_name: SlotName, value):
    """
    Store ``value`` at the dot-separated ``slot_name`` path below ``obj``.

    Everything before the last dot names the parent object, which is looked up
    with :py:func:`recursive_getattr`; the last segment is the attribute to write.
    An ``ExtractedGroupSlot`` value is merged into the existing attribute through
    that attribute's ``update`` method, any other value replaces the attribute.

    :param obj: Object to start the lookup from.
    :param slot_name: Attribute path such as ``"user.name"``.
    :param value: Value to store.
    """
    parent_path, dot, leaf = slot_name.rpartition(".")
    # An empty separator means there was no dot: the attribute lives directly on ``obj``.
    target = recursive_getattr(obj, parent_path) if dot else obj
    if not isinstance(value, ExtractedGroupSlot):
        setattr(target, leaf, value)
        return
    getattr(target, leaf).update(value)
[docs]
class BaseSlot(BaseModel, frozen=True):
    """
    Common interface shared by every slot type.

    Both :py:meth:`get_value` and :py:meth:`init_value` are abstract.
    """

    @abstractmethod
    async def get_value(self, ctx: Context) -> ExtractedSlot:
        """
        Extract the slot value from a :py:class:`~.Context`.

        :param ctx: Context to extract the value from.
        :return: The extraction result as an :py:class:`~.ExtractedSlot` instance.
        """
        raise NotImplementedError()

    @abstractmethod
    def init_value(self) -> ExtractedSlot:
        """
        Provide the initial value that slot storage is filled with.

        :return: Initial :py:class:`~.ExtractedSlot` for this slot.
        """
        raise NotImplementedError()
[docs]
class ValueSlot(BaseSlot, frozen=True):
    """
    Base class for slots that are designed to extract a concrete value.

    Subclass it to declare your own slot type.
    """

    default_value: Any = None

    async def get_value(self, ctx: Context) -> ExtractedValueSlot:
        """
        Call ``self.extract_value`` and wrap whatever happens into an ``ExtractedValueSlot``.

        The slot counts as extracted only when the call finishes and its result
        is not a ``SlotNotExtracted`` instance. An ``Exception`` raised by the
        call is logged and stored as the extracted value.

        :param ctx: Context passed on to ``self.extract_value``.
        :return: ``ExtractedValueSlot`` built with ``model_construct``.
        """
        # Placeholder that survives only if the ``try`` body is interrupted by
        # something the ``except Exception`` clause below does not handle.
        extracted_value = SlotNotExtracted("Caught an exit exception.")
        succeeded = False

        try:
            extracted_value = await wrap_sync_function_in_async(self.extract_value, ctx)
            succeeded = not isinstance(extracted_value, SlotNotExtracted)
        except Exception as error:
            logger.exception(f"Exception occurred during {self.__class__.__name__!r} extraction.", exc_info=error)
            extracted_value = error
        finally:
            if not succeeded:
                logger.debug(f"Slot {self.__class__.__name__!r} was not extracted: {extracted_value}")
            # Returning from ``finally`` makes the method return normally
            # no matter how the ``try`` statement was left.
            return ExtractedValueSlot.model_construct(
                is_slot_extracted=succeeded,
                extracted_value=extracted_value,
                default_value=self.default_value,
            )

    def init_value(self) -> ExtractedValueSlot:
        """
        Build the not-yet-extracted placeholder for this slot.

        :return: ``ExtractedValueSlot`` flagged as not extracted, carrying this
            slot's ``default_value``.
        """
        not_extracted_yet = SlotNotExtracted("Initial slot extraction.")
        return ExtractedValueSlot.model_construct(
            is_slot_extracted=False,
            extracted_value=not_extracted_yet,
            default_value=self.default_value,
        )
[docs]
class GroupSlot(BaseSlot, extra="allow", frozen=True):
    """
    A slot that groups other slots.

    Child slots are passed as extra keyword arguments and kept in
    ``__pydantic_extra__``; each child is annotated as either a nested
    :py:class:`~.GroupSlot` or a :py:class:`~.ValueSlot`.
    """

    __pydantic_extra__: Dict[str, Annotated[Union["GroupSlot", "ValueSlot"], Field(union_mode="left_to_right")]]
    allow_partial_extraction: bool = False
    """If True, extraction returns only successfully extracted child slots."""

    def __init__(self, allow_partial_extraction=False, **kwargs):
        super().__init__(allow_partial_extraction=allow_partial_extraction, **kwargs)

    @model_validator(mode="after")
    def __check_extra_field_names__(self):
        """
        Reject child slot names that contain a dot or are dunder names.

        :raises ValueError: For the first offending extra field name.
        """
        for field in self.__pydantic_extra__:
            if "." in field:
                raise ValueError(f"Extra field name cannot contain dots: {field!r}")
            is_dunder = field.startswith("__") and field.endswith("__")
            if is_dunder:
                raise ValueError(f"Extra field names cannot be dunder: {field!r}")
        return self

    async def get_value(self, ctx: Context) -> ExtractedGroupSlot:
        """
        Extract all child slots concurrently and collect the results.

        With ``allow_partial_extraction`` set, children whose result is not
        flagged as extracted are left out of the returned group.

        :param ctx: Context handed to every child's ``get_value``.
        :return: ``ExtractedGroupSlot`` keyed by child slot name.
        """
        children = self.__pydantic_extra__
        results = await asyncio.gather(*[child.get_value(ctx) for child in children.values()])
        return ExtractedGroupSlot(
            **{
                child_name: child_value
                for child_name, child_value in zip(children, results)
                if child_value.__slot_extracted__ or not self.allow_partial_extraction
            }
        )

    def init_value(self) -> ExtractedGroupSlot:
        """
        Build the initial storage entry for this group from its children.

        :return: ``ExtractedGroupSlot`` holding every child's ``init_value()``.
        """
        initial_values = {}
        for child_name, child in self.__pydantic_extra__.items():
            initial_values[child_name] = child.init_value()
        return ExtractedGroupSlot(**initial_values)
[docs]
class RegexpSlot(ValueSlot, frozen=True):
    """
    RegexpSlot is a slot type that extracts its value using a regular expression.

    Pass the pattern to the `regexp` argument as a string: the field is
    annotated as ``str``, so a pre-compiled pattern object does not match
    the declared type.
    If you want to extract a particular group, but not the full match,
    change the `match_group_idx` parameter.
    """

    # Regular expression pattern, declared as a plain string.
    regexp: str
    # Defaults to 0; per the class docstring the default corresponds to the full match.
    match_group_idx: int = 0
    "Index of the group to match."
[docs]
class FunctionSlot(ValueSlot, frozen=True):
    """
    A simpler version of :py:class:`~.ValueSlot`.

    Uses a user-defined `func` to extract slot value from the :py:attr:`~.Context.last_request` Message.
    """

    # User-supplied extractor. It receives a ``Message`` and returns either the
    # extracted value or a ``SlotNotExtracted`` instance, directly or as an awaitable.
    func: Callable[[Message], Union[Awaitable[Union[Any, SlotNotExtracted]], Any, SlotNotExtracted]]
[docs]
class SlotManager(BaseModel):
    """
    Provides API for managing slots.

    An instance of this class can be accessed via ``ctx.framework_data.slot_manager``.
    """

    slot_storage: ExtractedGroupSlot = Field(default_factory=ExtractedGroupSlot)
    """Slot storage. Stored inside ctx.framework_data."""
    root_slot: GroupSlot = Field(default_factory=GroupSlot, exclude=True)
    """Slot configuration passed during pipeline initialization."""

    def set_root_slot(self, root_slot: GroupSlot):
        """
        Install the ``root_slot`` configuration coming from the pipeline.

        The storage is rebuilt from ``root_slot.init_value()`` and then updated
        with the previously stored values:

        - New slots are added with their :py:meth:`~.BaseSlot.init_value`.
        - Old extracted slot values are preserved only if their configuration did not change,
          that is if they are still present in the config and their fundamental type did not change
          (i.e. `GroupSlot` did not turn into a `ValueSlot` or vice versa).

        This method is called by pipeline and is not supposed to be used otherwise.

        :param root_slot: New slot configuration.
        """
        self.root_slot = root_slot
        rebuilt_storage = root_slot.init_value()
        # Carry over whatever the ``update`` of the fresh storage accepts from the old one.
        rebuilt_storage.update(self.slot_storage)
        self.slot_storage = rebuilt_storage

    def get_slot(self, slot_name: SlotName) -> BaseSlot:
        """
        Look up a slot configuration by its name.

        :param slot_name: Dot-separated name of the slot inside ``root_slot``.
        :return: The matching slot configuration.
        :raises KeyError: If the slot with the specified name does not exist.
        """
        candidate = recursive_getattr(self.root_slot, slot_name)
        # Anything that is not a slot (including the ``None`` returned for
        # missing attributes) counts as "not found".
        if not isinstance(candidate, BaseSlot):
            raise KeyError(f"Could not find slot {slot_name!r}.")
        return candidate

    def unset_slot(self, slot_name: SlotName) -> None:
        """
        Mark specified slot as not extracted and clear extracted value.

        The stored slot is obtained through ``self.get_extracted_slot`` and
        reset with its ``__unset__`` method.

        :param slot_name: Name of the slot to reset.
        :raises KeyError: If the slot with the specified name does not exist.
        """
        stored_slot = self.get_extracted_slot(slot_name)
        stored_slot.__unset__()

    def unset_all_slots(self) -> None:
        """
        Mark all slots as not extracted and clear all extracted values.
        """
        storage = self.slot_storage
        storage.__unset__()

    def fill_template(self, template: str) -> Optional[str]:
        """
        Fill `template` string with extracted slot values and return a formatted string
        or None if an exception has occurred while trying to fill template.

        `template` should be a format-string:

        E.g. "Your username is {profile.username}".

        For the example above, if ``profile.username`` slot has value "admin",
        it would return the following text:
        "Your username is admin".

        :param template: Format-string referring to slots by name.
        :return: The formatted string, or ``None`` if formatting raised an exception.
        """
        try:
            formatter = self.KwargOnlyFormatter()
            # Top-level entries of the storage become the keyword arguments of ``format``.
            slot_values = dict(self.slot_storage.__pydantic_extra__.items())
            return formatter.format(template, **slot_values)
        except Exception as exc:
            logger.exception("An exception occurred during template filling.", exc_info=exc)
            return None