autointent.modules.embedding.RetrievalEmbedding
- class autointent.modules.embedding.RetrievalEmbedding(k, embedder_name, db_dir=None, embedder_device='cpu', batch_size=32, max_length=None, embedder_use_cache=False)
Bases: autointent.modules.abc.EmbeddingModule
Module for managing retrieval operations using a vector database.
RetrievalEmbedding provides methods for indexing, querying, and managing a vector database for tasks such as nearest neighbor retrieval.
- Variables:
vector_index – The vector index used for nearest neighbor retrieval.
name – Name of the module, defaults to “retrieval”.
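- Parameters:
k (int) – Number of nearest neighbors to retrieve.
embedder_name (str) – Name of the embedder to use.
db_dir (str | None) – Directory for the vector database, defaults to None.
embedder_device (str) – Device for the embedder, defaults to 'cpu'.
batch_size (int) – Batch size used by the embedder, defaults to 32.
max_length (int | None) – Maximum sequence length for the embedder, defaults to None.
embedder_use_cache (bool) – Whether the embedder uses a cache, defaults to False.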
Examples
```python
import tempfile

from autointent.modules.embedding import RetrievalEmbedding

utterances = ["bye", "how are you?", "good morning"]
labels = [0, 1, 1]

# db_dir was undefined in the original snippet; a temporary directory keeps the example self-contained.
db_dir = tempfile.mkdtemp()

retrieval = RetrievalEmbedding(
    k=2,
    embedder_name="sergeyzh/rubert-tiny-turbo",
    db_dir=db_dir,
)
retrieval.fit(utterances, labels)
predictions = retrieval.predict(["how is the weather today?"])
print(predictions)
```
([[1, 1]], [[0.1525942087173462, 0.18616724014282227]], [['good morning', 'how are you?']])
- vector_index: autointent.context.vector_index_client.VectorIndex
- name = 'retrieval'
- embedder_name
- embedder_device = 'cpu'
- batch_size = 32
- max_length = None
- embedder_use_cache = False
- classmethod from_context(context, k, embedder_name)
Create a RetrievalEmbedding instance using a Context object.
- Parameters:
context (autointent.context.Context) – The context containing configurations and utilities.
k (int) – Number of nearest neighbors to retrieve.
embedder_name (str) – Name of the embedder to use.
- Returns:
Initialized RetrievalEmbedding instance.
- Return type:
RetrievalEmbedding
- property db_dir: str
Get the directory for the vector database.
- Returns:
Path to the database directory.
- Return type:
str
- fit(utterances, labels)
Fit the vector index using the provided utterances and labels.
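A plausible mental model of fit and predict: fitting stores an embedding for each training utterance together with its label, and prediction runs a nearest-neighbor search over those stored embeddings. The sketch below illustrates that flow in pure Python under stated assumptions: `toy_embed` is a hypothetical bag-of-characters stand-in for a real sentence embedder, distances are assumed to be cosine distances, and the search is brute force. `ToyRetrieval` is hypothetical and is not how autointent's `VectorIndex` is implemented; it only mirrors the `(labels, distances, texts)` shape documented for `predict`.

```python
import math
from collections import Counter


def toy_embed(text: str) -> dict[str, float]:
    """Hypothetical stand-in for a real embedder: a unit-normalized bag of characters."""
    counts = Counter(text.lower())
    norm = math.sqrt(sum(c * c for c in counts.values()))
    return {ch: c / norm for ch, c in counts.items()}


def cosine_distance(a: dict[str, float], b: dict[str, float]) -> float:
    """1 - cosine similarity; both inputs are assumed to be unit-normalized."""
    return 1.0 - sum(v * b.get(k, 0.0) for k, v in a.items())


class ToyRetrieval:
    """Illustrative in-memory analogue of fit/predict (hypothetical, not autointent code)."""

    def __init__(self, k: int) -> None:
        self.k = k
        self._vectors: list[dict[str, float]] = []
        self._labels: list[int] = []
        self._texts: list[str] = []

    def fit(self, utterances: list[str], labels: list[int]) -> None:
        # "Fitting" here just embeds the training utterances and stores them with labels.
        self._vectors = [toy_embed(u) for u in utterances]
        self._labels = list(labels)
        self._texts = list(utterances)

    def predict(
        self, utterances: list[str]
    ) -> tuple[list[list[int]], list[list[float]], list[list[str]]]:
        all_labels, all_dists, all_texts = [], [], []
        for query in utterances:
            q = toy_embed(query)
            # Brute-force search: score every stored vector, keep the k closest.
            scored = sorted(
                (cosine_distance(q, v), i) for i, v in enumerate(self._vectors)
            )[: self.k]
            all_labels.append([self._labels[i] for _, i in scored])
            all_dists.append([d for d, _ in scored])
            all_texts.append([self._texts[i] for _, i in scored])
        return all_labels, all_dists, all_texts


retriever = ToyRetrieval(k=2)
retriever.fit(["bye", "how are you?", "good morning"], [0, 1, 1])
pred_labels, pred_dists, pred_texts = retriever.predict(["bye!"])
```

A real vector database replaces the brute-force loop with an index structure, but the contract is the same: k neighbors per query, ordered from nearest to farthest.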
- score(context, split, metric_fn)
Evaluate the embedding model using a specified metric function.
- Parameters:
context (autointent.context.Context) – The context containing test data and labels.
split (Literal['validation', 'test']) – Dataset split to evaluate on.
metric_fn (autointent.metrics.RetrievalMetricFn) – Function to compute the retrieval metric.
- Returns:
Computed metric score.
- Return type:
float
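One plausible reading of score is that it retrieves neighbors for the utterances of the chosen split and passes the true labels together with the retrieved neighbor labels to metric_fn. The exact RetrievalMetricFn signature is not shown on this page, so the function below, `retrieval_hit_rate`, and its arguments are hypothetical: it is a sketch of one common retrieval metric (hit rate at k), not an autointent function.

```python
from typing import Optional


def retrieval_hit_rate(
    query_labels: list[int],
    candidates_labels: list[list[int]],
    k: Optional[int] = None,
) -> float:
    """Hypothetical metric: fraction of queries whose top-k neighbors include the query's label."""
    if not query_labels:
        raise ValueError("query_labels must be non-empty")
    hits = 0
    for true_label, candidates in zip(query_labels, candidates_labels):
        # Only the first k neighbors count; k=None means use all retrieved neighbors.
        top = candidates if k is None else candidates[:k]
        hits += int(true_label in top)
    return hits / len(query_labels)


score_all = retrieval_hit_rate([1, 0], [[1, 1], [1, 0]])  # 1.0
score_top1 = retrieval_hit_rate([1, 0], [[1, 1], [1, 0]], k=1)  # 0.5
```

Truncating to k inside the metric lets the same retrieved lists be scored at several cutoffs without re-querying the index.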
- get_assets()
Get the retriever artifacts for this module.
- Returns:
A RetrieverArtifact object containing embedder information.
- Return type:
RetrieverArtifact
- clear_cache()
Clear cached data in memory used by the vector index.
- Return type:
None
- dump(path)
Save the module’s metadata and vector index to a specified directory.
- Parameters:
path (str) – Path to the directory where assets will be dumped.
- Return type:
None
- load(path)
Load the module’s metadata and vector index from a specified directory.
- Parameters:
path (str) – Path to the directory containing the dumped assets.
- Return type:
None
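dump and load form a round trip through a directory. The sketch below shows only the metadata half of that idea with the standard library; the file name `metadata.json`, the helper names, and the stored fields are hypothetical, and the real methods also persist the vector index, which this sketch omits.

```python
import json
import tempfile
from pathlib import Path


def dump_metadata(path: str, metadata: dict) -> None:
    """Hypothetical: write module metadata as JSON inside the target directory."""
    out_dir = Path(path)
    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / "metadata.json").write_text(json.dumps(metadata), encoding="utf-8")


def load_metadata(path: str) -> dict:
    """Hypothetical: read back the JSON written by dump_metadata."""
    return json.loads((Path(path) / "metadata.json").read_text(encoding="utf-8"))


meta = {"k": 2, "embedder_name": "sergeyzh/rubert-tiny-turbo", "batch_size": 32}
with tempfile.TemporaryDirectory() as tmp:
    dump_metadata(tmp, meta)
    # Reading back from the same directory should reproduce the original settings.
    restored = load_metadata(tmp)
```

Keeping hyperparameters in a plain JSON file next to the index makes a dumped directory easy to inspect and lets load rebuild the module without the original constructor call.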
- predict(utterances)
Predict the nearest neighbors for a list of utterances.
- Parameters:
utterances (list[str]) – List of utterances for which nearest neighbors are to be retrieved.
- Returns:
A tuple containing:
- labels: List of retrieved labels for each utterance.
- distances: List of distances to the nearest neighbors.
- texts: List of retrieved text data corresponding to the neighbors.
- Return type:
tuple[list[list[int | list[int]]], list[list[float]], list[list[str]]]
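The tuple returned by predict can be unpacked per query and turned into a decision. The sketch below reuses the output printed in the Examples section and applies a majority vote over neighbor labels; `majority_label` is a hypothetical helper, it assumes single-label integer labels (not the `list[int]` multi-label case), and it assumes, as in the example output, that neighbors are ordered from nearest to farthest.

```python
from collections import Counter

# Output of predict() as printed in the Examples section: (labels, distances, texts).
labels, distances, texts = (
    [[1, 1]],
    [[0.1525942087173462, 0.18616724014282227]],
    [["good morning", "how are you?"]],
)


def majority_label(neighbor_labels: list[int]) -> int:
    """Hypothetical k-NN decision rule: the most frequent label among the neighbors."""
    # most_common() orders equal counts by first occurrence, so a tie goes to the
    # label seen first, i.e. the closest neighbor when results are sorted by distance.
    return Counter(neighbor_labels).most_common(1)[0][0]


# One predicted label, closest training utterance, and its distance per query.
predicted = [majority_label(query_labels) for query_labels in labels]
nearest_texts = [query_texts[0] for query_texts in texts]
nearest_distances = [query_dists[0] for query_dists in distances]
```

A distance-weighted vote is a common alternative when neighbors at very different distances should not count equally.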