autointent.modules.scoring.LinearScorer
- class autointent.modules.scoring.LinearScorer(embedder_config=None, cv=3, seed=0)
Bases: autointent.modules.base.BaseScorer
Scoring module for linear classification using logistic regression.
This module uses embeddings generated from a transformer model to train a logistic regression classifier for intent classification.
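The underlying idea can be sketched with scikit-learn on precomputed vectors. This is a minimal sketch, not LinearScorer's actual implementation: random vectors stand in for the transformer embeddings, and the cross-validated logistic regression is assumed to be `sklearn.linear_model.LogisticRegressionCV`, with the scorer's `cv` and `seed` mapped onto its `cv` and `random_state` arguments.

```python
import numpy as np
from sklearn.linear_model import LogisticRegressionCV

# Stand-in for transformer embeddings: two well-separated clusters of
# 8-dimensional vectors (assumption; LinearScorer computes embeddings itself).
rng = np.random.default_rng(0)
train_embeddings = np.vstack([
    rng.normal(loc=-2.0, scale=0.5, size=(6, 8)),  # class 0
    rng.normal(loc=2.0, scale=0.5, size=(6, 8)),   # class 1
])
train_labels = np.array([0] * 6 + [1] * 6)

# Cross-validated logistic regression; cv and random_state mirror the
# scorer's cv and seed parameters (assumed mapping).
clf = LogisticRegressionCV(cv=3, random_state=0)
clf.fit(train_embeddings, train_labels)

# Scores are per-class probabilities: one row per utterance, one column per class.
test_embeddings = np.vstack([
    rng.normal(loc=-2.0, scale=0.5, size=(1, 8)),
    rng.normal(loc=2.0, scale=0.5, size=(1, 8)),
])
probabilities = clf.predict_proba(test_embeddings)
print(probabilities.round(3))
```

Keeping the embedding step separate from the classifier means the linear model only ever sees fixed-size vectors, which is what makes a plain logistic regression sufficient on top of a transformer encoder.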
- Parameters:
embedder_config (autointent.configs.EmbedderConfig | str | dict[str, Any] | None) – Config of the embedder model
cv (int) – Number of cross-validation folds, defaults to 3
n_jobs – Number of parallel jobs for cross-validation, defaults to None
seed (int) – Random seed for reproducibility, defaults to 0
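Assuming the folds are stratified by label (scikit-learn's default when an integer `cv` is passed to a classifier), `cv` should not exceed the number of training utterances in the rarest label. This is consistent with the example below, which uses `cv=2` with two utterances per label. The helper `max_usable_cv` is hypothetical, written only to make the check explicit; it is not part of AutoIntent.

```python
from collections import Counter


def max_usable_cv(labels: list[int]) -> int:
    """Largest number of stratified folds that still puts at least one
    example of every label into each fold.

    Hypothetical helper for illustration; not part of AutoIntent.
    """
    counts = Counter(labels)
    if not counts:
        raise ValueError("labels must not be empty")
    # Stratified k-fold needs at least k examples of the rarest label.
    return min(counts.values())


labels = [0, 1, 0, 1]  # labels from the example below
print(max_usable_cv(labels))  # 2, so the default cv=3 would be too large here
```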
Example:
from autointent.modules import LinearScorer

# Two cross-validation folds for this tiny training set
scorer = LinearScorer(
    embedder_config="sergeyzh/rubert-tiny-turbo",
    cv=2,
)
utterances = ["hello", "goodbye", "allo", "sayonara"]
labels = [0, 1, 0, 1]  # 0 = greeting, 1 = farewell
scorer.fit(utterances, labels)

test_utterances = ["hi", "bye"]
probabilities = scorer.predict(test_utterances)  # one row of label probabilities per utterance
print(probabilities)

Output:
[[0.50000032 0.49999968]
 [0.50000032 0.49999968]]
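Each row of the printed array holds one probability per label, so the predicted intent is the highest-scoring column. A minimal sketch, assuming `predict` returns a NumPy array (as the printed output suggests) and that column j corresponds to label j; the values are copied from the output above. Scores this close to 0.5 mean the classifier is essentially undecided on this toy training set.

```python
import numpy as np

# Scores copied from the example output: one row per test utterance,
# one column per label (column j assumed to correspond to label j).
probabilities = np.array([
    [0.50000032, 0.49999968],
    [0.50000032, 0.49999968],
])

predicted_labels = probabilities.argmax(axis=1)  # highest-scoring label per row
confidence = probabilities.max(axis=1)           # score of that label

for utterance, label, score in zip(["hi", "bye"], predicted_labels, confidence):
    print(f"{utterance!r}: label {label} (p={score:.4f})")
```

Both test utterances map to label 0 here, by a negligible margin.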
- name = 'linear'
Name of the module.
- supports_multiclass = True
Whether the module supports multiclass classification.
- supports_multilabel = True
Whether the module supports multilabel classification.
- cv = 3
- seed = 0
- embedder_config
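Since the module reports `supports_multilabel = True`, it can also score utterances that carry several labels at once. A minimal sketch of one way a linear scorer can do this, fitting one independent logistic regression per label with scikit-learn's `MultiOutputClassifier`. The one-classifier-per-label setup and the binary label-matrix format are assumptions for illustration, not a description of LinearScorer's internals, and random vectors again stand in for embeddings.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier

rng = np.random.default_rng(0)

# Stand-in embeddings: 40 random 8-dimensional vectors (assumption).
embeddings = rng.normal(size=(40, 8))

# Binary label matrix with 3 labels: label k is active when feature k is
# positive, so one utterance can have several active labels.
labels = (embeddings[:, :3] > 0).astype(int)

# One independent logistic regression per label.
clf = MultiOutputClassifier(LogisticRegression(random_state=0))
clf.fit(embeddings, labels)

# predict_proba returns one (n_samples, 2) array per label;
# column 1 of each is the probability that the label is active.
per_label = clf.predict_proba(embeddings[:5])
scores = np.column_stack([p[:, 1] for p in per_label])
print(scores.round(3))  # shape (5, 3): one independent probability per label
```

Unlike the multiclass case, the rows of `scores` need not sum to 1, because each label is decided independently.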
- classmethod from_context(context, embedder_config=None)
Create a LinearScorer instance using a Context object.
- Parameters:
context (autointent.Context) – Context containing configurations and utilities
embedder_config (autointent.configs.EmbedderConfig | str | None) – Config of the embedder, or None to use the best embedder
- Return type:
LinearScorer
- get_embedder_config()
Get the name of the embedder.
- fit(utterances, labels)
Train the logistic regression classifier.
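- Parameters:
utterances (list[str]) – Training utterances
labels – Labels corresponding to the training utterances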
- Raises:
ValueError – If the vector index does not match the provided utterances
- Return type:
None
- predict(utterances)
Predict probabilities for the given utterances.
- clear_cache()
Clear cached data in memory used by the embedder.
- Return type:
None