Source code for pyrelational.datasets.benchmark_datamanager

"""Utility to create datamanagers corresponding to different AL tasks
"""

import random
from collections import defaultdict
from typing import Any, List

import numpy as np
from sklearn.metrics import pairwise_distances
from torch.utils.data import Dataset

from pyrelational.data_managers.data_manager import DataManager


def pick_one_sample_per_class(dataset: Dataset[Any], train_indices: List[int]) -> List[int]:
    """Randomly pick one sample per class from the training subset of a dataset.

    The returned indices are used to define the initial state of the labelled
    subset in an active learning task.

    :param dataset: input dataset; ``dataset[idx][1]`` is read as the class
        label of sample ``idx`` and must be convertible with ``int()``
    :param train_indices: list or iterable with the indices corresponding to
        the training samples in the dataset
    :return: one dataset index per class found among ``train_indices``,
        ordered by the first appearance of each class; empty if
        ``train_indices`` is empty
    """
    # Group the training indices by their integer class label.
    class2idx = defaultdict(list)
    for idx in train_indices:
        class2idx[int(dataset[idx][1])].append(idx)

    # Exactly one random.choice call per class, in class-insertion order, so a
    # seeded ``random`` module yields reproducible representatives.
    return [random.choice(members) for members in class2idx.values()]
def create_warm_start(dataset: Dataset[Any], **dm_args: Any) -> DataManager:
    """Build a data manager for a 'warm start' active learning task.

    The intended setup is a datamanager with 10% randomly labelled data from
    the train indices, with the remaining training observations forming the
    unlabelled set. The name follows Konyushkova et al. (2017). It can be used
    for both classification and regression type datasets.

    Note that this function sets no split or labelling arguments itself: the
    dataset and ``dm_args`` are forwarded unchanged, so the initial labelling
    is decided by the ``DataManager`` constructor (presumably its defaults;
    verify against ``DataManager``) and by whatever is passed in ``dm_args``.

    Reference: Ksenia Konyushkova, Raphael Sznitman, Pascal Fua,
    'Learning Active Learning from Data', NIPS 2017.

    :param dataset: A pytorch dataset in the style described pyrelational.datasets
    :param dm_args: kwargs for any additional keyword arguments to be passed
        into the initialisation of the datamanager.
    :return: the constructed ``DataManager``
    """
    return DataManager(dataset, **dm_args)
def create_classification_cold_start(
    dataset: Dataset[Any], train_indices: List[int], test_indices: List[Any], **dm_args: Any
) -> DataManager:
    """Build a 'cold start' AL task for benchmarking classification datasets.

    The initial labelled set holds one randomly chosen example from each class
    present in the training subset of the data.

    Please note the current iteration does not utilise a validation set as
    described in the paper.

    :param dataset: A pytorch dataset in the style described pyrelational.datasets
    :param train_indices: [int] indices corresponding to observations of
        dataset used for training set
    :param test_indices: [int] indices corresponding to observations of
        dataset used for holdout test set
    :param dm_args: kwargs for any additional keyword arguments to be passed
        into the initialisation of the datamanager.
    :return: the constructed ``DataManager``
    """
    # One representative per class seeds the labelled pool.
    initial_labelled = pick_one_sample_per_class(dataset, train_indices)
    return DataManager(
        dataset,
        train_indices=train_indices,
        test_indices=test_indices,
        labelled_indices=initial_labelled,
        **dm_args,
    )
def create_regression_cold_start(
    dataset: Dataset[Any], train_indices: List[int], test_indices: List[Any], **dm_args: Any
) -> DataManager:
    """Build a 'cold start' AL task for benchmarking regression datasets.

    The data manager starts with 2 labelled samples: the pair of training
    samples that have the largest distance between them.

    Please note the current iteration does not utilise a validation set as
    described in the paper.

    :param dataset: A pytorch dataset in the style described pyrelational.datasets
    :param train_indices: [int] indices corresponding to observations of
        dataset used for training set
    :param test_indices: [int] indices corresponding to observations of
        dataset used for holdout test set
    :param dm_args: kwargs for any additional keyword arguments to be passed
        into the initialisation of the datamanager.
    :return: the constructed ``DataManager``
    """
    # Assumes the dataset can be indexed with a list of indices and that the
    # first element of the result holds the input features -- TODO confirm
    # against the dataset implementations used with this helper.
    train_features = dataset[train_indices][:][0]

    # Locate the entry with the largest pairwise distance; its row/column
    # positions are local to the training subset.
    distances = pairwise_distances(train_features)
    flat_argmax = np.argmax(distances, axis=None)
    local_first, local_second = np.unravel_index(flat_argmax, distances.shape)

    # Map the local positions back to indices in the full dataset.
    farthest_pair = [train_indices[local_first], train_indices[local_second]]

    return DataManager(
        dataset,
        train_indices=train_indices,
        test_indices=test_indices,
        labelled_indices=farthest_pair,
        **dm_args,
    )