"""
This module contains methods for scoring samples based on distances between
featurization of samples. These scorers are task-agnostic.
"""
import inspect
import logging
from typing import Any, Callable, List, Optional, Union, get_args
import numpy as np
import sklearn.cluster as sklust
import torch
from numpy.typing import NDArray
from sklearn.base import ClusterMixin
from sklearn.metrics import pairwise_distances_argmin, pairwise_distances_argmin_min
from torch import Tensor
from torch.utils.data import DataLoader
logging.basicConfig()
logger = logging.getLogger()
Array = Union[Tensor, NDArray[Any], List[Any]]
[docs]
def relative_distance(
query_set: Union[Array, DataLoader[Any]],
reference_set: Union[Array, DataLoader[Any]],
metric: Optional[Union[str, Callable[..., Any]]] = "euclidean",
axis: int = 1,
) -> Tensor:
"""
Function that return the minimum distance, according to input metric, from each sample in the query_set to the
samples in the reference set.
:param query_set: input containing the features of samples in the queryable pool. query set should either be an
array-like object or a pytorch dataloader whose first element in each bactch is a featurisation of the samples
in the batch.
:param reference_set: input containing the features of samples already queried samples against which the distances
are computed. reference set should either be an array-like object or a pytorch dataloader whose first element in
each bactch is a featurisation of the samples in the batch.
:param metric: defines the metric to be used to compute the distance. This should be supported by scikit-learn
pairwise_distances function.
:param axis: integer indicating which dimension the features are
:return: pytorch tensor of dimension the number of samples in query_set containing the minimum distance from each
sample to the reference set
"""
if isinstance(query_set, (Tensor, np.ndarray, list)):
query_set = np.array(query_set)
query_set = query_set.reshape((query_set.shape[0], -1))
if isinstance(reference_set, (Tensor, np.ndarray, list)):
reference_set = np.array(reference_set)
reference_set = reference_set.reshape((reference_set.shape[0], -1))
if isinstance(reference_set, np.ndarray) and isinstance(query_set, np.ndarray):
_, distances = pairwise_distances_argmin_min(query_set, reference_set, metric=metric, axis=axis)
elif isinstance(reference_set, np.ndarray) and isinstance(query_set, DataLoader):
distances = []
for q in query_set:
q = q[0].reshape((q[0].shape[0], -1))
distances.append(pairwise_distances_argmin_min(q, reference_set, metric=metric, axis=axis)[1])
distances = np.hstack(distances)
elif isinstance(reference_set, DataLoader) and isinstance(query_set, np.ndarray):
distances = []
for r in reference_set:
r = r[0].reshape((r[0].shape[0], -1))
distances.append(pairwise_distances_argmin_min(query_set, r, metric=metric, axis=axis)[1])
distances = np.min(np.vstack(distances), axis=0)
elif isinstance(reference_set, DataLoader) and isinstance(query_set, DataLoader):
distances = []
for q in query_set:
temp = []
q = q[0].reshape((q[0].shape[0], -1))
for r in reference_set:
r = r[0].reshape((r[0].shape[0], -1))
temp.append(pairwise_distances_argmin_min(q, r, metric=metric, axis=axis)[1])
distances.append(np.min(np.vstack(temp), axis=0))
distances = np.hstack(distances)
else:
raise TypeError("reference_set and query_set should either be an array_like structure or a pytorch DataLoader")
return torch.from_numpy(distances).float()
[docs]
def representative_sampling(
query_set: Union[Array, DataLoader[Any]],
num_annotate: int,
clustering_method: Union[str, ClusterMixin] = "KMeans",
**clustering_kwargs: Optional[Any],
) -> List[int]:
"""
Function that selects representative samples of the query set. Representative selection relies on clustering
algorithms in scikit-learn.
:param query_set: input containing the features of samples in the queryable pool. query set should either be an
array-like object or a pytorch dataloader whose first element in each bactch is a featurisation of the samples
in the batch
:param num_annotate: number of representative samples to identify
:param clustering_method: name, or instantiated class, of the clustering method to use
:param clustering_kwargs: arguments to be passed to instantiate clustering class if a string is passed to
clustering_method
:return: array-like containing the indices of the representative samples identified
"""
if isinstance(query_set, DataLoader):
out = []
for q in query_set:
out.append(q[0].reshape((q[0].shape[0], -1)))
query_set = torch.cat(out, 0)
query_set = np.array(query_set)
if num_annotate >= query_set.shape[0]: # if there are less samples than sought queries, return everything
ret: List[int] = np.arange(query_set.shape[0]).tolist()
return ret
if isinstance(clustering_method, str) and hasattr(sklust, clustering_method):
clustering_method = getattr(sklust, clustering_method)
if "n_clusters" in inspect.getfullargspec(clustering_method).args:
clustering_kwargs["n_clusters"] = num_annotate
clustering_cls = clustering_method(**clustering_kwargs)
elif isinstance(clustering_method, str):
raise ValueError(f"{clustering_method} is not part of the sklearn package")
elif isinstance(clustering_method, ClusterMixin):
clustering_cls = clustering_method
else:
raise TypeError(
"""clustering_method argument type not supported, it should be either a string pointing to a method of
sklearn or an instantiated clustering algorithm subclassing sklearn ClusterMixin"""
)
lbls = clustering_cls.fit_predict(query_set)
if hasattr(clustering_cls, "cluster_centers_indices_"):
indices: List[int] = clustering_cls.cluster_centers_indices_
representative_samples = indices
elif hasattr(clustering_cls, "cluster_centers_"):
representative_samples = get_closest_query_to_centroids(clustering_cls.cluster_centers_, query_set, lbls)
else:
logger.warning(
"""Clustering method does not return centroids to identify closest samples,
returning random sample from each cluster"""
)
representative_samples = get_random_query_from_cluster(lbls)
num_samples = min(num_annotate, len(representative_samples))
ret = np.random.choice( # in case there are more that num_annotates samples
representative_samples,
size=(num_samples,),
replace=False,
).tolist()
return ret
[docs]
def get_closest_query_to_centroids(
centroids: NDArray[np.float_],
query: NDArray[np.float_],
cluster_assignment: NDArray[np.int_],
) -> List[int]:
"""
Find the closest sample in query to centroids.
:param centroids: array containing centroids
:param query: array containing query samples
:param cluster_assignment: indicate what cluster each query sample is associated with
:return: list of indices of query samples
"""
out = []
for i in np.unique(cluster_assignment):
ixs = np.where(cluster_assignment == i)[0]
centroid = centroids[i].reshape(1, -1)
subquery = query[ixs]
j = pairwise_distances_argmin(centroid, subquery).item()
out.append(ixs[j])
return out
[docs]
def get_random_query_from_cluster(cluster_assignment: NDArray[np.int_]) -> List[int]:
"""
Get random indices drawn from each cluster.
:param cluster_assignment: array indicating what cluster each sample is associated with.
:return: list of indices of query samples
"""
out = []
for i in np.unique(cluster_assignment):
ixs = np.where(cluster_assignment == i)[0]
out.append(np.random.choice(ixs))
return out