"""
Mongo
-----
The Mongo module provides a MongoDB-based version of the :py:class:`.DBContextStorage` class.
This class is used to store and retrieve context data in a MongoDB.
It allows Chatsky to easily store and retrieve context data in a format that is highly scalable
and easy to work with.
MongoDB is a widely-used, open-source NoSQL database that is known for its scalability and performance.
It stores data in a format similar to JSON, making it easy to work with the data in a variety of programming languages
and environments. Additionally, MongoDB is highly scalable and can handle large amounts of data
and high levels of read and write traffic.
"""
from asyncio import gather
from typing import Any, Dict, Set, Tuple, Optional, List
try:
from pymongo import UpdateOne
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorClientSession
mongo_available = True
except ImportError:
AsyncIOMotorClientSession = Any
mongo_available = False
from chatsky.core.ctx_utils import ContextMainInfo
from .database import DBContextStorage, _SUBSCRIPT_DICT, NameConfig
from .protocol import get_protocol_install_suggestion
[docs]
class MongoContextStorage(DBContextStorage):
"""
Implements :py:class:`.DBContextStorage` with `mongodb` as the database backend.
CONTEXTS table is stored as `COLLECTION_PREFIX_contexts` collection.
LOGS table is stored as `COLLECTION_PREFIX_logs` collection.
:param path: Database URI. Example: `mongodb://user:password@host:port/dbname`.
:param rewrite_existing: Whether `TURNS` modified locally should be updated in database or not.
:param partial_read_config: Dictionary of subscripts for all possible turn items.
:param collection_prefix: "namespace" prefix for the two collections created for context storing.
"""
_UNIQUE_KEYS = "unique_keys"
_ID_FIELD = "_id"
is_concurrent: bool = True
def __init__(
self,
path: str,
rewrite_existing: bool = False,
partial_read_config: Optional[_SUBSCRIPT_DICT] = None,
collection_prefix: str = "chatsky_collection",
transactions_enabled: bool = False,
):
DBContextStorage.__init__(self, path, rewrite_existing, partial_read_config)
if not mongo_available:
install_suggestion = get_protocol_install_suggestion("mongodb")
raise ImportError("`mongodb` package is missing.\n" + install_suggestion)
self._transactions_enabled = transactions_enabled
self._mongo = AsyncIOMotorClient(self.full_path, uuidRepresentation="standard")
db = self._mongo.get_default_database()
self.main_table = db[f"{collection_prefix}_{NameConfig._main_table}"]
self.turns_table = db[f"{collection_prefix}_{NameConfig._turns_table}"]
[docs]
async def _connect(self):
await gather(
self.main_table.create_index(NameConfig._id_column, background=True, unique=True),
self.turns_table.create_index(
[NameConfig._id_column, NameConfig._key_column], background=True, unique=True
),
)
[docs]
async def _load_main_info(self, ctx_id: str) -> Optional[ContextMainInfo]:
result = await self.main_table.find_one(
{NameConfig._id_column: ctx_id},
NameConfig.get_context_main_fields,
)
return (
ContextMainInfo.model_validate({f: result[f] for f in NameConfig.get_context_main_fields})
if result is not None
else None
)
[docs]
async def _inner_update_context(
self,
ctx_id: str,
ctx_info_dump: Optional[Dict],
field_info: List[Tuple[str, List[Tuple[int, Optional[bytes]]]]],
session: Optional[AsyncIOMotorClientSession],
) -> None:
if ctx_info_dump is not None:
await self.main_table.update_one(
{NameConfig._id_column: ctx_id},
{
"$set": {
NameConfig._id_column: ctx_id,
}
| {f: ctx_info_dump[f] for f in NameConfig.get_context_main_fields}
},
upsert=True,
session=session,
)
if len(field_info) > 0:
await self.turns_table.bulk_write(
[
UpdateOne(
{NameConfig._id_column: ctx_id, NameConfig._key_column: k},
{"$set": {field_name: v}},
upsert=True,
)
for field_name, items in field_info
for k, v in items
],
session=session,
)
[docs]
async def _update_context(
self,
ctx_id: str,
ctx_info: Optional[ContextMainInfo],
field_info: List[Tuple[str, List[Tuple[int, Optional[bytes]]]]],
) -> None:
ctx_info_dump = ctx_info.model_dump(mode="python") if ctx_info is not None else None
if self._transactions_enabled:
async with await self._mongo.start_session() as session:
async with session.start_transaction():
await self._inner_update_context(ctx_id, ctx_info_dump, field_info, session)
else:
await self._inner_update_context(ctx_id, ctx_info_dump, field_info, None)
[docs]
async def _delete_context(self, ctx_id: str) -> None:
await gather(
self.main_table.delete_one({NameConfig._id_column: ctx_id}),
self.turns_table.delete_one({NameConfig._id_column: ctx_id}),
)
[docs]
async def _load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[int, bytes]]:
limit, key = 0, dict()
if isinstance(self._subscripts[field_name], int):
limit = self._subscripts[field_name]
elif isinstance(self._subscripts[field_name], Set):
key = {NameConfig._key_column: {"$in": list(self._subscripts[field_name])}}
result = (
await self.turns_table.find(
{NameConfig._id_column: ctx_id, field_name: {"$exists": True, "$ne": None}, **key},
[NameConfig._key_column, field_name],
sort=[(NameConfig._key_column, -1)],
)
.limit(limit)
.to_list(None)
)
return [(item[NameConfig._key_column], item[field_name]) for item in result]
[docs]
async def _load_field_keys(self, ctx_id: str, field_name: str) -> List[int]:
result = await self.turns_table.aggregate(
[
{"$match": {NameConfig._id_column: ctx_id, field_name: {"$ne": None}}},
{"$group": {"_id": None, self._UNIQUE_KEYS: {"$addToSet": f"${NameConfig._key_column}"}}},
]
).to_list(None)
return result[0][self._UNIQUE_KEYS] if len(result) == 1 else list()
[docs]
async def _load_field_items(self, ctx_id: str, field_name: str, keys: Set[int]) -> List[Tuple[int, bytes]]:
result = await self.turns_table.find(
{
NameConfig._id_column: ctx_id,
NameConfig._key_column: {"$in": list(keys)},
field_name: {"$exists": True, "$ne": None},
},
[NameConfig._key_column, field_name],
).to_list(None)
return [(item[NameConfig._key_column], item[field_name]) for item in result]
[docs]
async def _clear_all(self) -> None:
await gather(self.main_table.delete_many({}), self.turns_table.delete_many({}))