# Source code for pyrelational.strategies.task_agnostic.representative_sampling_strategy

"""Representative sampling based active learning strategy
"""

from typing import Any, List, Union

import numpy as np
import torch
from sklearn.base import ClusterMixin

from pyrelational.data_managers import DataManager
from pyrelational.informativeness import representative_sampling
from pyrelational.strategies.abstract_strategy import Strategy


class RepresentativeSamplingStrategy(Strategy):
    """Representative sampling based active learning strategy.

    Stacks the feature vectors of the currently unlabelled samples, passes them
    to ``representative_sampling`` together with the configured clustering
    method, and maps the selected positions back to the data manager's
    unlabelled indices.
    """

    def __init__(
        self,
        clustering_method: Union[str, ClusterMixin] = "KMeans",
        **clustering_kwargs: Any,
    ):
        """
        :param clustering_method: name, or instantiated class, of the clustering method to use
        :param clustering_kwargs: arguments to be passed to instantiate clustering class if a
            string is passed to clustering_method
        """
        super().__init__()
        self.clustering_method = clustering_method
        self.clustering_kwargs = clustering_kwargs

    def __call__(
        self,
        data_manager: DataManager,
        num_annotate: int,
    ) -> List[int]:
        """
        Call function which identifies samples which need to be labelled

        :param data_manager: A pyrelational data manager which keeps track of what has been labelled
            and creates data loaders for active learning
        :param num_annotate: number of samples to annotate
        :return: list of indices (taken from ``data_manager.u_indices``) to annotate; an empty list
            when there are no unlabelled samples left
        """
        unlabelled_indices = data_manager.u_indices
        # torch.stack raises on an empty sequence, so an exhausted pool would otherwise crash with
        # an unhelpful error; with nothing unlabelled there is nothing to select.
        if len(unlabelled_indices) == 0:
            return []

        unlabelled_features = torch.stack(data_manager.get_sample_feature_vectors(unlabelled_indices))
        representative_samples = representative_sampling(
            unlabelled_features,
            num_annotate=num_annotate,
            clustering_method=self.clustering_method,
            **self.clustering_kwargs,
        )
        # NOTE(review): representative_sampling presumably returns row positions into
        # unlabelled_features (hence the lookup below) — confirm. int() makes numpy/torch
        # integer scalars safe to use as sequence indices.
        return [unlabelled_indices[int(i)] for i in representative_samples]