# Source code for pyrelational.strategies.task_agnostic.relative_distance_strategy

from typing import List

import torch

from pyrelational.data_managers import DataManager
from pyrelational.informativeness import relative_distance
from pyrelational.strategies.abstract_strategy import Strategy


class RelativeDistanceStrategy(Strategy):
    """Active learning strategy that selects samples by diversity.

    Unlabelled samples are scored with ``relative_distance`` against the
    labelled set, and the highest scoring samples are proposed for annotation.
    """

    def __init__(self, metric: str = "euclidean"):
        """
        :param metric: Name of the distance metric to use. It should be one
            supported by the scikit-learn ``pairwise_distances`` function.
        """
        super().__init__()
        self.metric = metric

    def __call__(self, num_annotate: int, data_manager: DataManager) -> List[int]:
        """
        Identify the samples which should be labelled next.

        :param num_annotate: number of samples to annotate
        :param data_manager: A pyrelational data manager which keeps track of
            what has been labelled and creates data loaders for active learning
        :return: list of indices to annotate
        """
        unlabelled_loader = data_manager.get_unlabelled_loader()
        labelled_loader = data_manager.get_labelled_loader()
        scores = relative_distance(unlabelled_loader, labelled_loader, metric=self.metric)

        # Rank score positions from the highest score to the lowest.
        ranking = torch.argsort(scores, descending=True).tolist()

        # Translate the top-ranked positions through the data manager's
        # u_indices (presumably position k in `scores` matches entry k of
        # u_indices -- confirm against relative_distance).
        unlabelled_indices = data_manager.u_indices
        return [unlabelled_indices[position] for position in ranking[:num_annotate]]