# Source code for chatsky.context_storages.sql

"""
SQL
---
The SQL module provides a SQL-based version of the :py:class:`.DBContextStorage` class.
This class is used to store and retrieve context data from SQL databases.
It allows Chatsky to easily store and retrieve context data in a format that is highly scalable
and easy to work with.

The SQL module lets you choose a backend from MySQL, PostgreSQL, or SQLite,
whichever is most suitable for your use case and environment.
MySQL and PostgreSQL are widely used open-source relational databases that are known for their
reliability and scalability. SQLite is a self-contained, high-reliability, embedded, full-featured,
public-domain, SQL database engine.
"""

import asyncio
import importlib
import json
from typing import Hashable

from chatsky.core import Context

from .database import DBContextStorage, threadsafe_method
from .protocol import get_protocol_install_suggestion

# SQLAlchemy is an optional dependency: record whether its names could be imported
# instead of failing at module import time.
try:
    from sqlalchemy import Table, MetaData, Column, JSON, String, inspect, select, delete, func
    from sqlalchemy.ext.asyncio import create_async_engine

    sqlalchemy_available = True
except (ImportError, ModuleNotFoundError):
    sqlalchemy_available = False

# Per-backend driver flags; each one is switched on below only if its async driver imports.
postgres_available = sqlite_available = mysql_available = False

# PostgreSQL driver probe.
try:
    import asyncpg

    _ = asyncpg  # presumably only references the name so the import is not reported as unused

    postgres_available = True
except (ImportError, ModuleNotFoundError):
    pass

# MySQL driver probe.
try:
    import asyncmy

    _ = asyncmy  # presumably only references the name so the import is not reported as unused

    mysql_available = True
except (ImportError, ModuleNotFoundError):
    pass

# SQLite driver probe.
try:
    import aiosqlite

    _ = aiosqlite  # presumably only references the name so the import is not reported as unused

    sqlite_available = True
except (ImportError, ModuleNotFoundError):
    pass

# A driver without SQLAlchemy is not enough: without it every backend is reported unavailable.
if not sqlalchemy_available:
    postgres_available = sqlite_available = mysql_available = False


def import_insert_for_dialect(dialect: str):
    """
    Load the dialect-specific ``insert`` construct and publish it as the
    module-level ``insert`` name.

    :param dialect: Name of the sqlalchemy dialect whose ``insert`` should be used.
    """
    global insert
    dialect_module = importlib.import_module(f"sqlalchemy.dialects.{dialect}")
    insert = dialect_module.insert
class SQLContextStorage(DBContextStorage):
    """
    | SQL-based version of the :py:class:`.DBContextStorage`.
    | Compatible with MySQL, Postgresql, Sqlite.

    :param path: Standard sqlalchemy URI string.
        When using sqlite backend in Windows, keep in mind that you have to use double backslashes '\\'
        instead of forward slashes '/' in the file path.
    :param table_name: The name of the table to use.
    :param custom_driver: If you intend to use some other database driver instead of the recommended ones,
        set this parameter to `True` to bypass the import checks.
    """

    def __init__(self, path: str, table_name: str = "contexts", custom_driver: bool = False):
        DBContextStorage.__init__(self, path)

        self._check_availability(custom_driver)
        self.engine = create_async_engine(self.full_path, pool_pre_ping=True)
        self.dialect: str = self.engine.dialect.name

        id_column_args = {"primary_key": True}
        if self.dialect == "sqlite":
            # SQLite resolves primary-key conflicts by replacing the row, which is why
            # `_get_update_stmt` can return a plain INSERT for this dialect.
            id_column_args["sqlite_on_conflict_primary_key"] = "REPLACE"

        self.metadata = MetaData()
        self.table = Table(
            table_name,
            self.metadata,
            Column("id", String(36), **id_column_args),
            Column("context", JSON),  # column for storing serialized contexts
        )

        # NOTE: `asyncio.run` raises `RuntimeError` when called from a running event loop,
        # so the storage has to be constructed from synchronous code.
        asyncio.run(self._create_self_table())

        import_insert_for_dialect(self.dialect)
        # Keep this storage's own dialect-specific `insert`: the module-level `insert` is
        # overwritten whenever another storage with a different dialect is created.
        self._insert = insert

    @threadsafe_method
    async def set_item_async(self, key: Hashable, value: Context):
        """
        Store `value` under `key`, replacing any context already saved for that key.

        :param key: Context identifier; stored as `str(key)`.
        :param value: Context to store; validated and serialized to JSON-compatible data.
        """
        value = Context.model_validate(value)
        value = json.loads(value.model_dump_json())

        insert_stmt = self._insert(self.table).values(id=str(key), context=value)
        update_stmt = await self._get_update_stmt(insert_stmt)

        async with self.engine.connect() as conn:
            await conn.execute(update_stmt)
            await conn.commit()

    @threadsafe_method
    async def get_item_async(self, key: Hashable) -> Context:
        """
        Load the context stored under `key`.

        :param key: Context identifier; looked up as `str(key)`.
        :return: The validated context.
        :raises KeyError: If no row with this id exists.
        """
        stmt = select(self.table.c.context).where(self.table.c.id == str(key))
        async with self.engine.connect() as conn:
            result = await conn.execute(stmt)
            row = result.fetchone()
        if row is None:
            raise KeyError(key)
        return Context.model_validate(row[0])

    @threadsafe_method
    async def del_item_async(self, key: Hashable):
        """
        Delete the row stored under `key`.

        :param key: Context identifier; matched as `str(key)`.
        """
        stmt = delete(self.table).where(self.table.c.id == str(key))
        async with self.engine.connect() as conn:
            await conn.execute(stmt)
            await conn.commit()

    @threadsafe_method
    async def contains_async(self, key: Hashable) -> bool:
        """
        Check whether a row with id `str(key)` exists.

        :param key: Context identifier.
        :return: `True` if a matching row exists, `False` otherwise.
        """
        # Only the id is selected: existence does not require fetching the JSON payload.
        stmt = select(self.table.c.id).where(self.table.c.id == str(key))
        async with self.engine.connect() as conn:
            result = await conn.execute(stmt)
            return result.fetchone() is not None

    @threadsafe_method
    async def len_async(self) -> int:
        """
        Count the rows in the table.

        :return: Number of stored rows.
        """
        stmt = select(func.count()).select_from(self.table)
        async with self.engine.connect() as conn:
            result = await conn.execute(stmt)
            return result.fetchone()[0]

    @threadsafe_method
    async def clear_async(self):
        """
        Delete every row from the table.
        """
        stmt = delete(self.table)
        async with self.engine.connect() as conn:
            await conn.execute(stmt)
            await conn.commit()

    async def _create_self_table(self):
        """
        Create the storage table unless it already exists.
        """
        async with self.engine.begin() as conn:
            # `run_sync` passes the sync connection as the first argument (`bind`);
            # `checkfirst=True` makes `Table.create` skip creation if the table exists.
            await conn.run_sync(self.table.create, checkfirst=True)

    async def _get_update_stmt(self, insert_stmt):
        """
        Turn an INSERT statement into the insert-or-update statement for the current dialect.

        :param insert_stmt: Dialect-specific INSERT statement for `self.table`.
        :return: Statement that inserts the row or overwrites `context` for an existing id.
        """
        if self.dialect == "sqlite":
            # The id column is declared with `sqlite_on_conflict_primary_key="REPLACE"`.
            return insert_stmt
        if self.dialect == "mysql":
            return insert_stmt.on_duplicate_key_update(context=insert_stmt.inserted.context)
        return insert_stmt.on_conflict_do_update(
            index_elements=["id"], set_=dict(context=insert_stmt.excluded.context)
        )

    def _check_availability(self, custom_driver: bool):
        """
        Verify that the packages required for the URI scheme in `self.full_path` were imported.

        :param custom_driver: If `True`, all checks are skipped.
        :raises ImportError: If the URI starts with "postgresql", "mysql" or "sqlite"
            and the matching availability flag is `False`.
        """
        if custom_driver:
            return
        if self.full_path.startswith("postgresql") and not postgres_available:
            install_suggestion = get_protocol_install_suggestion("postgresql")
            raise ImportError("Packages `sqlalchemy` and/or `asyncpg` are missing.\n" + install_suggestion)
        elif self.full_path.startswith("mysql") and not mysql_available:
            install_suggestion = get_protocol_install_suggestion("mysql")
            raise ImportError("Packages `sqlalchemy` and/or `asyncmy` are missing.\n" + install_suggestion)
        elif self.full_path.startswith("sqlite") and not sqlite_available:
            install_suggestion = get_protocol_install_suggestion("sqlite")
            raise ImportError("Package `sqlalchemy` and/or `aiosqlite` is missing.\n" + install_suggestion)