autointent.modules.decision.ThresholdDecision
- class autointent.modules.decision.ThresholdDecision(thresh)
Bases: autointent.modules.abc.DecisionModule
Threshold predictor module.
ThresholdDecision uses a predefined threshold (or array of thresholds) to predict labels for single-label or multi-label classification tasks.
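The single-label case can be sketched in plain Python. This illustrates the idea and is not the autointent source. The helper name `single_label_decision` is hypothetical. The sketch also makes two assumptions: a sample whose best score falls below the threshold is rejected with the label `-1` (out-of-scope), and a score exactly equal to the threshold is accepted.

```python
from typing import List, Sequence


def single_label_decision(scores: Sequence[Sequence[float]], thresh: float) -> List[int]:
    """Hypothetical helper: argmax over classes, rejected if the best score is too low."""
    predictions = []
    for row in scores:
        # index of the highest score; max() keeps the first index on ties, like numpy.argmax
        best = max(range(len(row)), key=lambda i: row[i])
        # assumption: -1 marks a sample whose best score does not reach the threshold
        predictions.append(best if row[best] >= thresh else -1)
    return predictions


print(single_label_decision([[0.3, 0.7], [0.5, 0.5], [0.45, 0.2]], thresh=0.5))  # [1, 0, -1]
```

The first two rows reproduce the `[1 0]` output of the single-label example. The third row exercises the assumed rejection path.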
- Variables:
metadata_dict_name – Filename for saving metadata to disk.
multilabel – If True, the model supports multi-label classification.
n_classes – Number of classes in the dataset.
tags – Tags for predictions (if any).
name – Name of the predictor, defaults to “threshold”.
- Parameters:
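thresh (float | numpy.typing.NDArray[Any]) – Threshold applied to the scores: either a single float shared by all classes or an array of thresholds, one per class.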
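Because `thresh` may be either a scalar or an array, a multi-label decision needs to apply a scalar threshold to every class. The sketch below shows one way to do this. The `multilabel_decision` helper is hypothetical, and the non-strict `>=` comparison is an assumption, not documented behavior.

```python
from typing import List, Sequence, Union


def multilabel_decision(
    scores: Sequence[Sequence[float]],
    thresh: Union[float, Sequence[float]],
) -> List[List[int]]:
    """Hypothetical helper: binarize each class score against its threshold."""
    n_classes = len(scores[0])
    # a scalar threshold is broadcast to every class
    if isinstance(thresh, (int, float)):
        thresholds = [float(thresh)] * n_classes
    else:
        thresholds = [float(t) for t in thresh]
    if len(thresholds) != n_classes:
        raise ValueError(f"expected {n_classes} thresholds, got {len(thresholds)}")
    # a class is predicted when its score reaches its own threshold (assumed >=)
    return [[int(s >= t) for s, t in zip(row, thresholds)] for row in scores]


test_scores = [[0.3, 0.7], [0.6, 0.4]]
print(multilabel_decision(test_scores, 0.5))         # [[0, 1], [1, 0]]
print(multilabel_decision(test_scores, [0.2, 0.8]))  # [[1, 0], [1, 0]]
```

With a shared threshold of 0.5 the result matches the multi-label example. With per-class thresholds, a lenient first class and a strict second class change which labels are switched on.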
Examples
Single-label classification
```python
from autointent.modules import ThresholdDecision
import numpy as np

scores = np.array([[0.2, 0.8], [0.6, 0.4], [0.1, 0.9]])
labels = [1, 0, 1]
threshold = 0.5
predictor = ThresholdDecision(thresh=threshold)
predictor.fit(scores, labels)

test_scores = np.array([[0.3, 0.7], [0.5, 0.5]])
predictions = predictor.predict(test_scores)
print(predictions)
# [1 0]
```
Multi-label classification
```python
# Reuses `np`, `ThresholdDecision` and the training `scores` from the single-label example.
labels = [[1, 0], [0, 1], [1, 1]]
predictor = ThresholdDecision(thresh=[0.5, 0.5])
predictor.fit(scores, labels)

test_scores = np.array([[0.3, 0.7], [0.6, 0.4]])
predictions = predictor.predict(test_scores)
print(predictions)
# [[0 1]
#  [1 0]]
```
- metadata: ThresholdDecisionDumpMetadata
- tags: list[autointent.schemas.Tag] | None
- name = 'threshold'
- thresh
- classmethod from_context(context, thresh=0.5)
Initialize from context.
- Parameters:
context (autointent.Context) – Context
thresh (float | numpy.typing.NDArray[Any]) – Threshold
- Return type:
ThresholdDecision
- fit(scores, labels, tags=None)
Fit the model.
- Parameters:
scores (numpy.typing.NDArray[Any]) – Scores to fit
labels (list[autointent.custom_types.LabelType]) – Labels to fit
tags (list[autointent.schemas.Tag] | None) – Tags to fit
- Return type:
None
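`fit` receives the scores and labels, and the Variables list above mentions `multilabel` and `n_classes` attributes. One plausible way to derive them is sketched below. The `fit_sketch` function and the `FitInfo` class are hypothetical, and the consistency checks are assumptions about what a threshold module would reasonably validate, not documented behavior.

```python
from dataclasses import dataclass
from typing import List, Sequence, Union

Label = Union[int, List[int]]


@dataclass
class FitInfo:
    multilabel: bool
    n_classes: int


def fit_sketch(
    scores: Sequence[Sequence[float]],
    labels: Sequence[Label],
    thresh: Union[float, Sequence[float]],
) -> FitInfo:
    """Hypothetical stand-in for ThresholdDecision.fit: infer task type and class count."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same number of samples")
    # multi-label targets are binary indicator vectors; single-label targets are ints
    multilabel = isinstance(labels[0], (list, tuple))
    # the number of classes is taken from the width of the score matrix
    n_classes = len(scores[0])
    if multilabel and any(len(row) != n_classes for row in labels):
        raise ValueError("each multi-label target needs one entry per class")
    # a per-class threshold array must match the number of classes
    if not isinstance(thresh, (int, float)) and len(thresh) != n_classes:
        raise ValueError(f"expected {n_classes} thresholds, got {len(thresh)}")
    return FitInfo(multilabel=multilabel, n_classes=n_classes)


scores = [[0.2, 0.8], [0.6, 0.4], [0.1, 0.9]]
print(fit_sketch(scores, [1, 0, 1], 0.5))                          # FitInfo(multilabel=False, n_classes=2)
print(fit_sketch(scores, [[1, 0], [0, 1], [1, 1]], [0.5, 0.5]))    # FitInfo(multilabel=True, n_classes=2)
```

This sketch takes the class count from the score matrix rather than from the labels. That way it stays correct even when some class never appears in the training labels.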
- predict(scores)
Predict labels from the scores by applying the threshold (or per-class thresholds).
- Parameters:
scores (numpy.typing.NDArray[Any]) – Scores to predict
- Return type:
numpy.typing.NDArray[Any]