Source code for pyrelational.strategies.regression.thompson_sampling_strategy

from torch import Tensor

from pyrelational.informativeness import regression_thompson_sampling
from pyrelational.strategies.regression.abstract_regression_strategy import (
    RegressionStrategy,
)


[docs] class ThompsonSamplingStrategy(RegressionStrategy): """Implements Thompson Sampling Strategy whereby unlabelled samples are scored and queried based on the thompson sampling scorer"""
[docs] def scoring_function(self, predictions: Tensor) -> Tensor: return regression_thompson_sampling(predictions)