# Source code for pyrelational.strategies.classification.marginal_confidence_strategy

"""
Active learning strategy that uses the margin confidence uncertainty measure,
computed between classes in the posterior predictive distribution, to choose
which observations to propose to the oracle.
"""

from torch import Tensor

from pyrelational.informativeness import classification_margin_confidence
from pyrelational.strategies.classification.abstract_classification_strategy import (
    ClassificationStrategy,
)


class MarginalConfidenceStrategy(ClassificationStrategy):
    """Implements Marginal Confidence Strategy whereby unlabelled samples are
    scored and queried based on the margin confidence classification scorer.

    Per-sample scoring is delegated entirely to
    ``pyrelational.informativeness.classification_margin_confidence``; the
    selection/query logic itself is inherited from ``ClassificationStrategy``.
    """

    def scoring_function(self, predictions: Tensor) -> Tensor:
        """Score unlabelled samples using the margin confidence measure.

        :param predictions: model predictions for the unlabelled samples,
            forwarded unchanged to ``classification_margin_confidence``
            (presumably class probabilities from the posterior predictive
            distribution; confirm the expected shape against that function)
        :return: tensor of informativeness scores as returned by
            ``classification_margin_confidence``
        """
        return classification_margin_confidence(predictions)