autointent.modules.scoring.KNNScorer
- class autointent.modules.scoring.KNNScorer(embedder_name, k, weights='distance', db_dir=None, embedder_device='cpu', batch_size=32, max_length=None, embedder_use_cache=False)
Bases: autointent.modules.abc.ScoringModule
K-nearest neighbors (KNN) scorer for intent classification.
This module retrieves the k nearest neighbors of each query utterance from a vector index and combines their labels, according to the chosen weighting strategy, into class probabilities.
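To make the retrieve-then-weight procedure concrete, here is a minimal NumPy sketch under stated assumptions: embeddings are already computed, neighbors are found by brute-force cosine distance instead of through autointent's VectorIndex, and votes use inverse-distance weighting. The function knn_scores is hypothetical and is not the library's implementation.

```python
import numpy as np


def knn_scores(
    train_emb: np.ndarray,
    train_labels: np.ndarray,
    query_emb: np.ndarray,
    k: int,
    n_classes: int,
    eps: float = 1e-8,
) -> np.ndarray:
    """Hypothetical distance-weighted KNN scorer; not autointent's implementation."""
    # L2-normalize rows so that a dot product equals cosine similarity
    train_n = train_emb / np.linalg.norm(train_emb, axis=1, keepdims=True)
    query_n = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)
    # Cosine distance from every query to every training utterance,
    # clipped at 0 to absorb floating-point round-off
    dist = np.clip(1.0 - query_n @ train_n.T, 0.0, None)
    # This sketch cannot retrieve more neighbors than there are training utterances
    k = min(k, train_emb.shape[0])
    # Indices of the k nearest training utterances for each query
    nn_idx = np.argsort(dist, axis=1)[:, :k]
    scores = np.zeros((query_emb.shape[0], n_classes))
    for i, neighbors in enumerate(nn_idx):
        for j in neighbors:
            # Inverse-distance vote: closer neighbors contribute more
            scores[i, train_labels[j]] += 1.0 / (dist[i, j] + eps)
    # Normalize each row into a probability distribution over classes
    return scores / scores.sum(axis=1, keepdims=True)


# Toy 2-D "embeddings": two class-0 points and one class-1 point
train = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]])
labels = np.array([0, 0, 1])
query = np.array([[1.0, 0.05]])
probs = knn_scores(train, labels, query, k=5, n_classes=2)
```

The query lies almost on top of the class-0 points, so nearly all probability mass goes to class 0. Requesting k=5 with only three training points works here only because the sketch clamps k.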
- Variables:
weights – Weighting strategy used for scoring.
_vector_index – VectorIndex instance for neighbor retrieval.
name – Name of the scorer, defaults to “knn”.
prebuilt_index – Flag indicating if the vector index is prebuilt.
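- Parameters:
embedder_name (str) – Name of the embedder used to vectorize utterances.
k (int) – Number of closest neighbors to consider during inference.
weights (autointent.custom_types.WEIGHT_TYPES) – Weighting strategy for scoring.
db_dir (str | None) – Path to the database directory for the vector index.
embedder_device (str) – Device on which the embedder runs.
batch_size (int) – Batch size used when computing embeddings.
max_length (int | None) – Maximum input sequence length for the embedder.
embedder_use_cache (bool) – Whether to cache computed embeddings.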
Examples
>>> import tempfile
>>> from autointent.modules.scoring import KNNScorer
>>> db_dir = tempfile.mkdtemp()  # temporary directory for the vector index
>>> utterances = ["hello", "how are you?"]
>>> labels = [0, 1]
>>> scorer = KNNScorer(
...     embedder_name="sergeyzh/rubert-tiny-turbo",
...     k=5,
...     db_dir=db_dir,
... )
>>> scorer.fit(utterances, labels)
>>> test_utterances = ["hi", "what's up?"]
>>> probabilities = scorer.predict(test_utterances)
>>> print(probabilities)  # predicted class probabilities, one row per utterance
[[0.67297815 0.32702185]
 [0.44031678 0.55968322]]
- weights: autointent.custom_types.WEIGHT_TYPES
- name = 'knn'
- embedder_name
- k
- embedder_device = 'cpu'
- batch_size = 32
- embedder_use_cache = False
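The weights attribute decides how neighbor votes become probabilities. The sketch below contrasts the default "distance" strategy with a "uniform" one. The "uniform" name and both formulas are assumptions for illustration; the values WEIGHT_TYPES actually accepts should be checked in autointent.custom_types. The helper aggregate is hypothetical.

```python
import numpy as np


def aggregate(neighbor_labels, neighbor_dists, n_classes, weights):
    """Hypothetical helper: one query's neighbors -> class probabilities."""
    if weights == "uniform":
        # Assumed strategy: every neighbor casts an equal vote
        w = np.ones(len(neighbor_labels))
    elif weights == "distance":
        # Inverse-distance votes: nearer neighbors count more
        w = 1.0 / (np.asarray(neighbor_dists, dtype=float) + 1e-8)
    else:
        raise ValueError(f"unsupported weighting strategy: {weights!r}")
    # Sum the votes per class, keeping a slot for every class
    scores = np.bincount(neighbor_labels, weights=w, minlength=n_classes)
    return scores / scores.sum()


# One query with three retrieved neighbors: two far class-1, one near class-0
neighbor_labels = np.array([1, 1, 0])
neighbor_dists = np.array([0.5, 0.6, 0.1])
uniform = aggregate(neighbor_labels, neighbor_dists, 2, "uniform")
distance = aggregate(neighbor_labels, neighbor_dists, 2, "distance")
```

Uniform voting follows the two-vote majority (class 1), while distance weighting lets the single much closer neighbor decide (class 0).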
- property db_dir: str
Get the database directory for the vector index.
- Returns:
Path to the database directory.
- Return type:
str
- classmethod from_context(context, k, weights, embedder_name=None)
Create a KNNScorer instance using a Context object.
- Parameters:
context (autointent.context.Context) – Context containing configurations and utilities.
k (int) – Number of closest neighbors to consider during inference.
weights (autointent.custom_types.WEIGHT_TYPES) – Weighting strategy for scoring.
embedder_name (str | None) – Name of the embedder, or None to use the best embedder.
- Returns:
Initialized KNNScorer instance.
- Return type:
autointent.modules.scoring.KNNScorer
- fit(utterances, labels)
Fit the scorer by training or loading the vector index.
- Parameters:
utterances (list[str]) – Training utterances.
labels – Class labels, one for each utterance.
- Raises:
ValueError – If the vector index does not match the provided utterances.
- Return type:
None
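One way such a mismatch check could work is sketched below, assuming (hypothetically) that a prebuilt index keeps the texts it was built from in insertion order. check_index_matches is an illustrative helper, not an autointent function.

```python
def check_index_matches(index_texts: list[str], utterances: list[str]) -> None:
    """Hypothetical guard: raise if a prebuilt index was built from different utterances."""
    # A size mismatch means the index cannot line up with the labels
    if len(index_texts) != len(utterances):
        raise ValueError(
            f"vector index holds {len(index_texts)} texts, "
            f"but {len(utterances)} utterances were provided"
        )
    # Same size but different content or order is also a mismatch,
    # because row i of the index is assumed to belong to labels[i]
    if index_texts != utterances:
        raise ValueError("vector index texts do not match the provided utterances")


# Matching index: passes silently
check_index_matches(["hello", "how are you?"], ["hello", "how are you?"])

# Mismatched index: the guard raises ValueError
mismatch_error = None
try:
    check_index_matches(["hello"], ["hello", "how are you?"])
except ValueError as err:
    mismatch_error = str(err)
```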
- predict(utterances)
Predict class probabilities for the given utterances.
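The printed output in the example suggests a NumPy array with one row per utterance and one column per class. A short sketch of post-processing it, reusing those printed values; the 0.6 threshold and the -1 "reject" label are arbitrary choices for illustration, not library behavior.

```python
import numpy as np

# Values printed by the KNNScorer example (rows = utterances, columns = classes)
probabilities = np.array([
    [0.67297815, 0.32702185],
    [0.44031678, 0.55968322],
])

# Each row is a distribution over classes
row_sums = probabilities.sum(axis=1)

# Most likely class for each utterance
predicted = probabilities.argmax(axis=1)

# Illustrative rejection rule: keep a prediction only if its probability
# reaches an arbitrary threshold, otherwise emit -1
threshold = 0.6
decisions = np.where(probabilities.max(axis=1) >= threshold, predicted, -1)
```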
- predict_with_metadata(utterances)
Predict class probabilities along with metadata for the given utterances.
- clear_cache()
Clear cached data in memory used by the vector index.
- Return type:
None