autointent.modules.embedding.LogregAimedEmbedding
- class autointent.modules.embedding.LogregAimedEmbedding(embedder_config, cv=3)
Bases: autointent.modules.base.BaseEmbedding
Module for configuring embeddings optimized for linear classification.
This module is meant to be used at the embedding node to optimize the embedding configuration, using the classification quality of a logistic regression trained on the embeddings as a proxy metric.
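To make the proxy-metric idea concrete, here is a minimal pure-Python sketch, not the module's implementation. It scores precomputed embedding vectors by the mean held-out accuracy of a multinomial logistic regression over `cv` folds, so a more informative embedder should receive a higher score. The real module relies on scikit-learn's `LogisticRegressionCV`, which also tunes regularization; the plain gradient-descent trainer, the unstratified fold split, and the names `train_softmax` and `cv_accuracy` are assumptions made for illustration.

```python
import math
import random

# Illustrative sketch only; not autointent's implementation.


def _logits(W, b, x):
    # One score per class: w_k . x + b_k
    return [sum(wj * xj for wj, xj in zip(w, x)) + bk for w, bk in zip(W, b)]


def _softmax(z):
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def train_softmax(X, y, n_classes, epochs=200, lr=0.5):
    """Fit unregularized multinomial logistic regression by batch gradient descent."""
    dim = len(X[0])
    W = [[0.0] * dim for _ in range(n_classes)]
    b = [0.0] * n_classes
    n = len(X)
    for _ in range(epochs):
        gW = [[0.0] * dim for _ in range(n_classes)]
        gb = [0.0] * n_classes
        for x, label in zip(X, y):
            probs = _softmax(_logits(W, b, x))
            for k in range(n_classes):
                err = probs[k] - (1.0 if k == label else 0.0)
                gb[k] += err
                for j in range(dim):
                    gW[k][j] += err * x[j]
        for k in range(n_classes):
            b[k] -= lr * gb[k] / n
            for j in range(dim):
                W[k][j] -= lr * gW[k][j] / n
    return W, b


def predict(W, b, x):
    z = _logits(W, b, x)
    return max(range(len(z)), key=z.__getitem__)


def cv_accuracy(X, y, cv=3, seed=0):
    """Mean held-out accuracy over `cv` folds (the proxy score for one embedder)."""
    idx = list(range(len(X)))
    random.Random(seed).shuffle(idx)
    folds = [idx[k::cv] for k in range(cv)]  # unstratified round-robin split
    n_classes = max(y) + 1
    scores = []
    for fold in folds:
        held_out = set(fold)
        train = [i for i in idx if i not in held_out]
        W, b = train_softmax([X[i] for i in train], [y[i] for i in train], n_classes)
        correct = sum(predict(W, b, X[i]) == y[i] for i in fold)
        scores.append(correct / len(fold))
    return sum(scores) / len(scores)
```

With toy vectors whose first coordinate separates the two classes, the score approaches 1.0, whereas giving every utterance the same vector cannot beat predicting the training-fold majority class. Comparing such scores across candidate embedders is the selection signal the module is described as providing.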
- Parameters:
embedder_config (autointent.configs.EmbedderConfig | str | dict[str, Any]) – Config of the embedder used for creating embeddings
cv (pydantic.PositiveInt) – Number of folds used in LogisticRegressionCV
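Because `embedder_config` accepts three forms, one way to picture how they could be reconciled is a normalization step like the sketch below. `EmbedderConfigSketch`, its `batch_size` field, and `normalize_embedder_config` are hypothetical stand-ins, not the real `autointent.configs.EmbedderConfig`; the only behavior taken from this page is that a plain string such as `"sergeyzh/rubert-tiny-turbo"` names a model.

```python
from dataclasses import dataclass, fields
from typing import Any, Union


@dataclass
class EmbedderConfigSketch:
    # Hypothetical stand-in for autointent.configs.EmbedderConfig.
    model_name: str
    batch_size: int = 32  # hypothetical extra field, for illustration only


def normalize_embedder_config(
    config: Union[EmbedderConfigSketch, str, dict[str, Any]],
) -> EmbedderConfigSketch:
    """Turn any accepted form into a config object (hypothetical helper)."""
    if isinstance(config, EmbedderConfigSketch):
        return config
    if isinstance(config, str):
        # A bare string is treated as a model name / Hugging Face model id.
        return EmbedderConfigSketch(model_name=config)
    if isinstance(config, dict):
        known = {f.name for f in fields(EmbedderConfigSketch)}
        unknown = set(config) - known
        if unknown:
            raise ValueError(f"unknown embedder config keys: {sorted(unknown)}")
        return EmbedderConfigSketch(**config)
    raise TypeError(f"unsupported embedder config type: {type(config).__name__}")
```

Rejecting unknown dictionary keys early is a design choice of this sketch; it surfaces typos before any model is loaded.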
Examples:
from autointent.modules.embedding import LogregAimedEmbedding

utterances = ["bye", "how are you?", "good morning"]
labels = [0, 1, 1]
retrieval = LogregAimedEmbedding(
    embedder_config="sergeyzh/rubert-tiny-turbo",
    cv=2,
)
retrieval.fit(utterances, labels)
- name = 'logreg_embedding'
Name of the module.
- supports_multiclass = True
Whether the module supports multiclass classification.
- supports_multilabel = True
Whether the module supports multilabel classification.
- supports_oos = False
Whether the module supports out-of-scope (OOS) data.
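Class-level flags like these can let a pipeline decide whether a module is applicable to a task before instantiating it. A minimal sketch of such a check follows; the `TaskSketch` description and the `unsupported_features` helper are hypothetical and not part of autointent, and the stand-in class only copies the flag values documented on this page.

```python
from dataclasses import dataclass


class LogregAimedEmbeddingFlags:
    # Stand-in that mirrors the documented class attributes; not the real class.
    name = "logreg_embedding"
    supports_multiclass = True
    supports_multilabel = True
    supports_oos = False


@dataclass
class TaskSketch:
    # Hypothetical description of a dataset/task.
    multilabel: bool
    has_oos: bool


def unsupported_features(module_cls: type, task: TaskSketch) -> list[str]:
    """List the task features the module class does not declare support for."""
    problems = []
    if task.multilabel and not module_cls.supports_multilabel:
        problems.append("multilabel")
    if not task.multilabel and not module_cls.supports_multiclass:
        problems.append("multiclass")
    if task.has_oos and not module_cls.supports_oos:
        problems.append("oos")
    return problems
```

Under these flags, a multilabel task without OOS samples passes, while any task containing OOS samples is reported as unsupported.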
- embedder_config
- cv = 3
- classmethod from_context(context, embedder_config, cv=3)
Create a LogregAimedEmbedding instance using a Context object.
- Parameters:
context (autointent.Context) – Context containing configurations and utilities
embedder_config (autointent.configs.EmbedderConfig | str) – Config of the embedder to use
cv (pydantic.PositiveInt) – Number of folds used in LogisticRegressionCV
- Return type:
LogregAimedEmbedding
- clear_cache()
Clear embedder from memory.
- Return type:
None
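A common way to implement a method like `clear_cache()` is to hold the embedder in an attribute that is loaded on first use and reset to `None` when the cache is cleared. The sketch below assumes that pattern; `LazyEmbedderSketch` and its `loader` callable are hypothetical, and whether the real module reloads the embedder automatically afterwards is not stated on this page.

```python
from typing import Callable, Optional


class LazyEmbedderSketch:
    """Hypothetical lazy holder: load on first use, release on clear_cache()."""

    def __init__(self, loader: Callable[[], object]) -> None:
        self._loader = loader
        self._model: Optional[object] = None
        self.load_count = 0  # how many times the (expensive) loader ran

    def get(self) -> object:
        if self._model is None:
            self._model = self._loader()
            self.load_count += 1
        return self._model

    def clear_cache(self) -> None:
        # Drop the reference so the model becomes eligible for garbage collection.
        self._model = None
```

Dropping the reference only frees memory once nothing else points at the model; releasing GPU memory may need additional framework-specific steps.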
- fit(utterances, labels)
Train the logistic regression model using the provided utterances and labels.
- score_ho(context, metrics)
Evaluate the embedding model with the specified metric functions using hold-out validation.
- score_cv(context, metrics)
Evaluate the embedding model with the specified metric functions using cross-validation.
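The two scoring methods presumably differ in how validation data are drawn: a single held-out split versus several cross-validation folds. The generic sketch below contrasts the two schemes. It is not autointent's code: `score_holdout`, `score_crossval`, and the `scorer(train_x, train_y, test_x, test_y)` callback are hypothetical, and the real methods presumably take their data splits from `context`.

```python
import random
from typing import Callable, Sequence

# Hypothetical callback: fit on the train part, return a metric on the test part.
Scorer = Callable[[Sequence, Sequence, Sequence, Sequence], float]


def score_holdout(x, y, scorer: Scorer, test_fraction=0.25, seed=0) -> float:
    """One shuffled train/test split; the metric comes from a single run."""
    idx = list(range(len(x)))
    random.Random(seed).shuffle(idx)
    n_test = max(1, int(len(idx) * test_fraction))
    test, train = idx[:n_test], idx[n_test:]
    return scorer([x[i] for i in train], [y[i] for i in train],
                  [x[i] for i in test], [y[i] for i in test])


def score_crossval(x, y, scorer: Scorer, n_folds=3, seed=0) -> float:
    """Every sample is tested exactly once; the metric is averaged over folds."""
    idx = list(range(len(x)))
    random.Random(seed).shuffle(idx)
    folds = [idx[k::n_folds] for k in range(n_folds)]
    results = []
    for fold in folds:
        held_out = set(fold)
        train = [i for i in idx if i not in held_out]
        results.append(scorer([x[i] for i in train], [y[i] for i in train],
                              [x[i] for i in fold], [y[i] for i in fold]))
    return sum(results) / len(results)
```

Hold-out is cheaper (one fit), while cross-validation costs `n_folds` fits but yields a less split-dependent estimate, which matters on small intent datasets.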
- get_assets()
Get the embedding artifacts for this module.
- Returns:
EmbeddingArtifact object containing embedder information
- Return type: