autointent.modules.SklearnScorer
- class autointent.modules.SklearnScorer(clf_name='LogisticRegression', embedder_config=None, **clf_args)
Bases: autointent.modules.base.BaseScorer
Scoring module for classification using sklearn classifiers.
This module uses embeddings generated by a transformer model to train the chosen sklearn classifier for intent classification.
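The two-stage design (embed the utterances, then fit a classifier on the vectors and read per-class probabilities) can be sketched with the standard library alone. This is a minimal illustration, not the module's implementation: toy_embed is a hypothetical character-hashing stand-in for the transformer embedder, and NearestCentroid is a hypothetical stand-in for an sklearn classifier that mimics the fit/predict_proba interface.

```python
import math
from collections import defaultdict


def toy_embed(text: str, dim: int = 8) -> list[float]:
    """Hypothetical embedder: hash characters into counts, then L2-normalize."""
    vec = [0.0] * dim
    for ch in text.lower():
        vec[ord(ch) % dim] += 1.0
    norm = math.sqrt(sum(v * v for v in vec)) or 1.0
    return [v / norm for v in vec]


class NearestCentroid:
    """Hypothetical classifier with an sklearn-like fit/predict_proba interface."""

    def fit(self, X, y):
        dim = len(X[0])
        sums = defaultdict(lambda: [0.0] * dim)
        counts = defaultdict(int)
        for x, label in zip(X, y):
            counts[label] += 1
            sums[label] = [s + v for s, v in zip(sums[label], x)]
        self.classes_ = sorted(counts)
        # One mean vector per class.
        self.centroids_ = [[s / counts[c] for s in sums[c]] for c in self.classes_]
        return self

    def predict_proba(self, X):
        probs = []
        for x in X:
            # Softmax over dot-product similarity to each class centroid.
            scores = [sum(a * b for a, b in zip(x, c)) for c in self.centroids_]
            top = max(scores)
            exps = [math.exp(s - top) for s in scores]
            total = sum(exps)
            probs.append([e / total for e in exps])
        return probs


utterances = ["hello", "how are you?"]
labels = [0, 1]
clf = NearestCentroid().fit([toy_embed(u) for u in utterances], labels)
probabilities = clf.predict_proba([toy_embed(u) for u in ["hi", "what's up?"]])
```

Keeping the embedder and the classifier separate is what lets the classifier be swapped by name while the embedding step stays the same.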
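- Parameters:
clf_name (str) – Name of the sklearn classifier to use
embedder_config (autointent.configs.EmbedderConfig | str | None) – Config of the embedder
**clf_args (dict[str, float | str | bool]) – Arguments for the chosen sklearn classifier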
Examples
>>> from autointent.modules.scoring import SklearnScorer
>>> utterances = ["hello", "how are you?"]
>>> labels = [0, 1]
>>> scorer = SklearnScorer(
...     clf_name="LogisticRegression",
...     embedder_config="sergeyzh/rubert-tiny-turbo",
... )
>>> scorer.fit(utterances, labels)
>>> test_utterances = ["hi", "what's up?"]
>>> probabilities = scorer.predict(test_utterances)
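The clf_name string selects a classifier class, and any extra keyword arguments (**clf_args) are forwarded to its constructor. The sketch below shows that name-to-class lookup under stated assumptions: the two dataclasses are hypothetical stand-ins for the real sklearn estimators, and build_classifier is a hypothetical helper, not part of autointent.

```python
from dataclasses import dataclass


@dataclass
class LogisticRegression:
    """Hypothetical stand-in for sklearn.linear_model.LogisticRegression."""
    C: float = 1.0
    max_iter: int = 100


@dataclass
class KNeighborsClassifier:
    """Hypothetical stand-in for sklearn.neighbors.KNeighborsClassifier."""
    n_neighbors: int = 5


# Map class names (the strings passed as clf_name) to classifier classes.
CLASSIFIERS = {cls.__name__: cls for cls in (LogisticRegression, KNeighborsClassifier)}


def build_classifier(clf_name: str = "LogisticRegression", **clf_args):
    """Look up a classifier by name and forward the extra keyword arguments."""
    try:
        cls = CLASSIFIERS[clf_name]
    except KeyError:
        raise ValueError(f"Unknown classifier: {clf_name!r}") from None
    return cls(**clf_args)


clf = build_classifier("KNeighborsClassifier", n_neighbors=3)
```

With real scikit-learn, sklearn.utils.all_estimators(type_filter="classifier") returns (name, class) pairs that could populate such a map; whether SklearnScorer resolves names this way is not stated here.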
- name = 'sklearn'
Name of the module to reference in search space configuration.
- supports_multilabel = True
Whether the module supports multilabel classification.
- supports_multiclass = True
Whether the module supports multiclass classification.
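Multiclass and multilabel tasks differ in the shape of their labels: one intent id per utterance versus any number of intents per utterance. The sketch below assumes, for illustration only, that multiclass labels are plain integers and multilabel labels are 0/1 indicator lists; to_multi_hot and is_multilabel are hypothetical helpers, not autointent functions.

```python
def to_multi_hot(label_sets, n_classes):
    """Convert per-utterance sets of intent ids into 0/1 indicator vectors."""
    return [[int(c in labels) for c in range(n_classes)] for labels in label_sets]


def is_multilabel(labels):
    """Treat list-valued labels as multilabel and plain ints as multiclass."""
    return bool(labels) and isinstance(labels[0], list)


multiclass_labels = [0, 1, 2]  # exactly one intent per utterance
# Second utterance carries two intents; the third carries none.
multilabel_labels = to_multi_hot([{0}, {1, 2}, set()], n_classes=3)
```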
- embedder_config
- clf_name = 'LogisticRegression'
- classmethod from_context(context, clf_name='LogisticRegression', embedder_config=None, **clf_args)
Create a SklearnScorer instance using a Context object.
- Parameters:
context (autointent.Context) – Context containing configurations and utilities
clf_name (str) – Name of the sklearn classifier to use
embedder_config (autointent.configs.EmbedderConfig | str | None) – Config of the embedder, or None to use the best embedder
**clf_args (dict[str, float | str | bool]) – Arguments for the chosen sklearn classifier
- Return type:
typing_extensions.Self
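The embedder_config argument accepts a config object, a string, or None, where None means "use the best embedder". A minimal sketch of that resolution follows, assuming a bare string is a model name and that the context exposes the best embedder as an attribute. EmbedderConfig, SketchContext, best_embedder, and resolve_embedder are all hypothetical stand-ins, not the autointent API.

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class EmbedderConfig:
    """Hypothetical stand-in for autointent.configs.EmbedderConfig."""
    model_name: str


@dataclass
class SketchContext:
    """Hypothetical context holding the embedder chosen earlier in the search."""
    best_embedder: EmbedderConfig


def resolve_embedder(context, embedder_config: Union[EmbedderConfig, str, None]) -> EmbedderConfig:
    if embedder_config is None:
        # Fall back to the best embedder recorded in the context.
        return context.best_embedder
    if isinstance(embedder_config, str):
        # Treat a bare string as a model name.
        return EmbedderConfig(model_name=embedder_config)
    return embedder_config


ctx = SketchContext(best_embedder=EmbedderConfig("sergeyzh/rubert-tiny-turbo"))
```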
- get_implicit_initialization_params()
Return the default parameters used in the __init__ method.
Some parameters of the module may be inferred from the context rather than passed to __init__. They still need to be logged for reproducibility when the module is loaded from disk.
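Why inferred parameters must be logged becomes clearer with a round trip: a value filled in from the context is saved, then passed back explicitly when the module is rebuilt without that context. In this sketch SketchScorer is a hypothetical class that mirrors the constructor/from_context split, and from_context is simplified to take the best embedder directly instead of a Context object.

```python
import json


class SketchScorer:
    """Hypothetical scorer mirroring the __init__ / from_context split."""

    def __init__(self, clf_name="LogisticRegression", embedder_config=None, **clf_args):
        self.clf_name = clf_name
        self.embedder_config = embedder_config
        self.clf_args = clf_args

    @classmethod
    def from_context(cls, best_embedder, clf_name="LogisticRegression", embedder_config=None, **clf_args):
        # embedder_config may be inferred here instead of being given by the caller.
        return cls(clf_name, embedder_config or best_embedder, **clf_args)

    def get_implicit_initialization_params(self):
        # Record the inferred value so the module can be rebuilt without the context.
        return {"embedder_config": self.embedder_config}


scorer = SketchScorer.from_context("sergeyzh/rubert-tiny-turbo", C=0.5)
saved = json.dumps(scorer.get_implicit_initialization_params())
restored = SketchScorer(clf_name=scorer.clf_name, **scorer.clf_args, **json.loads(saved))
```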
- fit(utterances, labels)
Train the chosen sklearn classifier.
- Parameters:
utterances (list[str]) – Training utterances
labels – Labels corresponding to the utterances
- Raises:
ValueError – If the vector index does not match the provided utterances
- Return type:
None
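One plausible reading of the ValueError above is a consistency check run before training: the stored vectors must line up one-to-one with the utterances being fitted. The sketch below assumes that interpretation; validate_training_inputs is a hypothetical helper and the exact condition checked by SklearnScorer may differ.

```python
def validate_training_inputs(utterances, embeddings):
    """Hypothetical pre-training check: exactly one vector per utterance."""
    if len(embeddings) != len(utterances):
        raise ValueError(
            f"Vector index holds {len(embeddings)} vectors "
            f"but {len(utterances)} utterances were given"
        )
```

Failing early here avoids training the classifier on vectors that belong to different texts.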
- predict(utterances)
Predict probabilities for the given utterances.
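The returned probabilities still have to be turned into intents. Assuming one row of per-class probabilities per utterance (as with sklearn's predict_proba), two common rules are argmax for multiclass and a per-class threshold for multilabel. Both helpers below are hypothetical and the 0.5 threshold is an arbitrary example value.

```python
def argmax_decision(probabilities):
    """Multiclass: pick the single most probable intent per utterance."""
    return [max(range(len(row)), key=row.__getitem__) for row in probabilities]


def threshold_decision(probabilities, threshold=0.5):
    """Multilabel: keep every intent whose probability reaches the threshold."""
    return [[int(p >= threshold) for p in row] for row in probabilities]


probs = [[0.8, 0.2], [0.3, 0.7]]
```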
- clear_cache()
Clear cached data in memory used by the embedder.
- Return type:
None
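Caching embeddings avoids recomputing vectors for repeated texts, at the cost of holding them in memory until the cache is cleared. A minimal sketch of that trade-off, assuming a simple per-text dictionary cache; CachingEmbedder is hypothetical and the lambda stands in for a real model.

```python
class CachingEmbedder:
    """Hypothetical embedder wrapper that memoizes embeddings per text."""

    def __init__(self, embed_fn):
        self._embed_fn = embed_fn
        self._cache = {}
        self.calls = 0  # how many times the underlying model was invoked

    def embed(self, texts):
        result = []
        for text in texts:
            if text not in self._cache:
                self.calls += 1
                self._cache[text] = self._embed_fn(text)
            result.append(self._cache[text])
        return result

    def clear_cache(self):
        # Drop stored vectors to free memory; later calls recompute them.
        self._cache.clear()


embedder = CachingEmbedder(lambda text: [float(len(text))])
embedder.embed(["hi", "hi"])  # the second "hi" is served from the cache
```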