Source code for pyrelational.strategies.task_agnostic.relative_distance_strategy

"""Relative distance based active learning strategy."""

from typing import List

from pyrelational.batch_mode_samplers import TopKSampler
from pyrelational.data_managers import DataManager
from pyrelational.informativeness import RelativeDistanceScorer
from pyrelational.strategies.abstract_strategy import Strategy


class RelativeDistanceStrategy(Strategy):
    """Active learning strategy that selects samples by diversity.

    Informativeness is computed by a :class:`RelativeDistanceScorer` from the
    unlabelled and labelled data loaders; the samples to annotate are then
    chosen by a :class:`TopKSampler`.
    """

    scorer: RelativeDistanceScorer

    def __init__(self, metric: str = "euclidean"):
        """Set up the strategy for a given distance metric.

        :param metric: Name of the distance metric to use. It should be one
            supported by the scikit-learn ``pairwise_distances`` function.
        """
        # Keep the metric name on the instance before wiring up the parent.
        self.metric = metric
        distance_scorer = RelativeDistanceScorer(metric=metric)
        batch_sampler = TopKSampler()
        super().__init__(distance_scorer, batch_sampler)

    def __call__(self, num_annotate: int, data_manager: DataManager) -> List[int]:
        """Select the samples that should be labelled next.

        :param num_annotate: number of samples to annotate
        :param data_manager: A pyrelational data manager which keeps track of
            what has been labelled and creates data loaders for active
            learning
        :return: list of indices to annotate
        """
        # Build the loaders in the order the scorer takes them:
        # unlabelled pool first, labelled set second.
        unlabelled_loader = data_manager.get_unlabelled_loader()
        labelled_loader = data_manager.get_labelled_loader()
        informativeness = self.scorer(unlabelled_loader, labelled_loader)

        # The sampler gets the scores together with the unlabelled indices
        # (``u_indices``) and the requested batch size.
        candidate_indices = data_manager.u_indices
        return self.sampler(informativeness, candidate_indices, num_annotate)