autointent.modules.decision.ThresholdDecision

class autointent.modules.decision.ThresholdDecision(thresh)

Bases: autointent.modules.abc.DecisionModule

Threshold predictor module.

ThresholdDecision uses a predefined threshold (or array of thresholds) to predict labels for single-label or multi-label classification tasks.
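The decision rule can be sketched in plain NumPy. This is a minimal illustration, not the library's implementation: the function name `threshold_decision_sketch` is hypothetical, and labeling single-label samples whose best score falls below the threshold with `-1` (out of scope) is an assumption, since this page does not say what happens in that case.

```python
from __future__ import annotations

import numpy as np
import numpy.typing as npt


def threshold_decision_sketch(
    scores: npt.NDArray[np.float64],
    thresh: float | npt.NDArray[np.float64],
    multilabel: bool,
) -> npt.NDArray[np.int64]:
    """Hypothetical sketch of threshold-based label prediction."""
    if multilabel:
        # Multi-label: each class is predicted independently when its score
        # reaches its threshold (scalar or per-class array, broadcast over rows).
        return (scores >= thresh).astype(np.int64)
    # Single-label: pick the best-scoring class for each sample.
    best = scores.argmax(axis=1)
    # Look up the threshold that applies to each winning class.
    class_thresh = np.broadcast_to(thresh, (scores.shape[1],))[best]
    best_scores = scores[np.arange(len(scores)), best]
    # Assumption: samples below their threshold are labeled -1 (out of scope).
    return np.where(best_scores >= class_thresh, best, -1)


# Same inputs as the single-label example on this page.
print(threshold_decision_sketch(np.array([[0.3, 0.7], [0.5, 0.5]]), 0.5, multilabel=False))  # [1 0]
```

Note that `argmax` returns the first index on ties, which is why the `[0.5, 0.5]` row resolves to class 0 in this sketch.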

Variables:
  • metadata_dict_name – Filename for saving metadata to disk.

  • multilabel – If True, the model supports multi-label classification.

  • n_classes – Number of classes in the dataset.

  • tags – Tags for predictions (if any).

  • name – Name of the predictor, defaults to “threshold”.

Parameters:

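thresh (float | numpy.typing.NDArray[Any]) – Threshold applied to the scores: either a single value shared by all classes or an array with one value per class.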
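Whether `thresh` is a single float or an array, comparing it against a score matrix relies on NumPy broadcasting: a scalar applies to every class, while a 1-D array of length `n_classes` supplies one threshold per column. The snippet below only illustrates that broadcasting behavior; the variable names are illustrative and not part of the module.

```python
import numpy as np

scores = np.array([[0.3, 0.7], [0.6, 0.4]])

# A scalar threshold is compared with every score in the matrix.
shared = scores >= 0.5
# A per-class array broadcasts along the last axis: column j uses thresh[j].
per_class = scores >= np.array([0.2, 0.8])

print(shared.astype(int).tolist())     # [[0, 1], [1, 0]]
print(per_class.astype(int).tolist())  # [[1, 0], [1, 0]]
```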

Examples

Single-label classification

```python
from autointent.modules import ThresholdDecision
import numpy as np

scores = np.array([[0.2, 0.8], [0.6, 0.4], [0.1, 0.9]])
labels = [1, 0, 1]
threshold = 0.5
predictor = ThresholdDecision(thresh=threshold)
predictor.fit(scores, labels)
test_scores = np.array([[0.3, 0.7], [0.5, 0.5]])
predictions = predictor.predict(test_scores)
print(predictions)  # [1 0]
```

Multi-label classification

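```python
from autointent.modules import ThresholdDecision
import numpy as np

scores = np.array([[0.2, 0.8], [0.6, 0.4], [0.1, 0.9]])
labels = [[1, 0], [0, 1], [1, 1]]
# One threshold per class, passed as an array to match the thresh type
predictor = ThresholdDecision(thresh=np.array([0.5, 0.5]))
predictor.fit(scores, labels)
test_scores = np.array([[0.3, 0.7], [0.6, 0.4]])
predictions = predictor.predict(test_scores)
print(predictions)
# [[0 1]
#  [1 0]]
```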

metadata: ThresholdDecisionDumpMetadata

multilabel: bool

n_classes: int

tags: list[autointent.schemas.Tag] | None

name = 'threshold'

thresh

classmethod from_context(context, thresh=0.5)

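Initialize the module from a context.

Parameters:
  • context – Context to initialize the module from.

  • thresh – Threshold value(s) for the decision; defaults to 0.5.

Return type:

ThresholdDecision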

fit(scores, labels, tags=None)#

Fit the model.

Parameters:
  • scores (numpy.typing.NDArray[Any]) – Scores to fit

  • labels (list[autointent.custom_types.LabelType]) – Labels to fit

  • tags (list[autointent.schemas.Tag] | None) – Tags to fit

Return type:

None
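The `multilabel` and `n_classes` attributes listed above describe the training data. One plausible way to derive them from `labels` is sketched below; `infer_task_metadata` is a hypothetical helper, and both the inference rule (a list per sample means multi-label) and the length check on an array `thresh` are assumptions, not documented behavior of `fit`.

```python
from __future__ import annotations

import numpy as np


def infer_task_metadata(labels: list, thresh: float | np.ndarray) -> tuple[bool, int]:
    """Hypothetical helper: derive (multilabel, n_classes) from training labels."""
    # Assumption: multi-label samples are given as per-class indicator lists.
    multilabel = isinstance(labels[0], list)
    # Multi-label: width of the indicator vector; single-label: distinct class ids.
    n_classes = len(labels[0]) if multilabel else len(set(labels))
    # Assumption: an array threshold must provide exactly one value per class.
    if not isinstance(thresh, float) and len(thresh) != n_classes:
        msg = f"expected {n_classes} thresholds, got {len(thresh)}"
        raise ValueError(msg)
    return multilabel, n_classes


print(infer_task_metadata([1, 0, 1], 0.5))                                 # (False, 2)
print(infer_task_metadata([[1, 0], [0, 1], [1, 1]], np.array([0.5, 0.5])))  # (True, 2)
```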

predict(scores)#

Predict labels from the given scores.

Parameters:

scores (numpy.typing.NDArray[Any]) – Scores to predict labels from

Return type:

numpy.typing.NDArray[Any]

dump(path)#

Dump the metadata to disk.

Parameters:

path (str) – Path to dump the metadata to

Return type:

None

load(path)#

Load the metadata from disk.

Parameters:

path (str) – Path to load the metadata from

Return type:

None
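`dump` and `load` persist the module's metadata under a path, and `metadata_dict_name` names the file it is written to. The following is a minimal sketch of such a round trip using JSON; the file name `metadata.json`, the JSON format, treating `path` as a directory, and the metadata keys are all assumptions for illustration, not the library's actual on-disk layout.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical file name standing in for metadata_dict_name.
METADATA_FILE = "metadata.json"


def dump_metadata_sketch(metadata: dict, path: str) -> None:
    """Write metadata as JSON into the directory at `path` (hypothetical layout)."""
    directory = Path(path)
    directory.mkdir(parents=True, exist_ok=True)
    with (directory / METADATA_FILE).open("w", encoding="utf-8") as f:
        json.dump(metadata, f)


def load_metadata_sketch(path: str) -> dict:
    """Read back the JSON metadata written by dump_metadata_sketch."""
    with (Path(path) / METADATA_FILE).open(encoding="utf-8") as f:
        return json.load(f)


# Round trip; a NumPy threshold array would need .tolist() first,
# because the json module cannot encode NumPy arrays directly.
with tempfile.TemporaryDirectory() as tmp:
    meta = {"multilabel": True, "n_classes": 2, "thresh": [0.5, 0.5]}
    dump_metadata_sketch(meta, tmp)
    restored = load_metadata_sketch(tmp)
    print(restored == meta)  # True
```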