Using your own datasets with PyRelationAL

The pyrelational.data_managers.data_manager.DataManager module enables users to integrate any PyTorch Dataset into PyRelationAL easily. The module expects the full dataset, i.e. the union of the labelled, unlabelled, validation (optional), and test sets. The indices of each set should be provided to the class constructor, which then constructs the subset Dataset objects under the hood. Throughout the experiment, the data manager keeps track of indices and handles updates to the labelled/unlabelled pools of samples. For instance, using the MNIST dataset:

import torch
from torchvision import datasets, transforms
from pyrelational.data_managers.data_manager import DataManager

# download MNIST and apply the standard normalisation transform
mnist_dataset = datasets.MNIST(
    "mnist_data",
    download=True,
    train=True,
    transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),
)

# split the dataset into train/validation/test subsets and collect their indices
train_ds, val_ds, test_ds = torch.utils.data.random_split(mnist_dataset, [50000, 5000, 5000])
train_indices = train_ds.indices
validation_indices = val_ds.indices
test_indices = test_ds.indices

# treat the first 10000 train samples as the initially labelled pool
labelled_indices = train_indices[:10000]

data_manager = DataManager(
    mnist_dataset,
    train_indices=train_indices,
    labelled_indices=labelled_indices,
    validation_indices=validation_indices,
    test_indices=test_indices,
)
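Once constructed, the data manager tracks which train indices are labelled and which remain unlabelled. As a minimal sanity check of the initial pools (the labelled_indices and unlabelled_indices attributes used below are assumptions about the DataManager API and may be named differently in your installed version):

# sanity-check the initial pools; the attribute names here are assumptions
print(len(data_manager.labelled_indices))    # 10000 initially labelled samples
print(len(data_manager.unlabelled_indices))  # 40000 samples left to query from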

Customizing the dataloader

Users can customize the dataloaders in the same way as any PyTorch DataLoader by passing DataLoader arguments, prefixed with loader_, to the data manager constructor. For instance:

data_manager = DataManager(
    mnist_dataset,
    train_indices=train_indices,
    labelled_indices=labelled_indices,
    validation_indices=validation_indices,
    test_indices=test_indices,
    loader_batch_size=10000,
    loader_num_workers=2,
    loader_shuffle=True,
)
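These loader_* keyword arguments mirror the standard torch.utils.data.DataLoader kwargs, so the batches produced match those of a plain loader built over the same subset. A sketch for comparison (the explicit Subset construction here is illustrative only, not part of the DataManager API):

from torch.utils.data import DataLoader, Subset

# a plain DataLoader with the equivalent kwargs, over the labelled subset
plain_loader = DataLoader(
    Subset(mnist_dataset, labelled_indices),
    batch_size=10000,
    num_workers=2,
    shuffle=True,
)
images, targets = next(iter(plain_loader))
print(images.shape)  # torch.Size([10000, 1, 28, 28])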

Interacting with non-PyTorch estimators

Importantly, this makes it possible to use PyTorch Datasets and DataLoaders to interact with other libraries by taking advantage of the collate function. For instance, the following collate function converts each batch to NumPy arrays:

import numpy as np

def numpy_collate(batch):
    """Collate function for a PyTorch to NumPy DataLoader"""
    # stack the per-sample elements (e.g. images and targets) into NumPy arrays
    return [np.stack(el) for el in zip(*batch)]

data_manager = DataManager(
    mnist_dataset,
    train_indices=train_indices,
    labelled_indices=labelled_indices,
    validation_indices=validation_indices,
    test_indices=test_indices,
    loader_collate_fn=numpy_collate,
)
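The batches yielded by such a loader can then be fed directly to estimators that expect NumPy inputs. As a minimal sketch with a scikit-learn classifier (the plain DataLoader and the image flattening below are specific to this illustration; inside PyRelationAL the data manager constructs the equivalent loaders itself):

from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader, Subset

# NumPy-yielding loader over the labelled subset, for illustration
loader = DataLoader(
    Subset(mnist_dataset, labelled_indices),
    batch_size=1000,
    collate_fn=numpy_collate,
)
X, y = next(iter(loader))  # X: (1000, 1, 28, 28), y: (1000,)
X = X.reshape(len(X), -1)  # flatten images to (1000, 784)

clf = LogisticRegression(max_iter=200).fit(X, y)
print(clf.score(X, y))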

Returning a single batch

In some cases, for instance when using Gaussian Processes or scikit-learn estimators, the dataloader should return the entire underlying dataset in a single batch. This can be specified as follows:

data_manager = DataManager(
    mnist_dataset,
    train_indices=train_indices,
    labelled_indices=labelled_indices,
    validation_indices=validation_indices,
    test_indices=test_indices,
    loader_batch_size="full",
)
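With this setting, one iteration of a loader yields a single batch covering the whole corresponding split. A minimal sketch, again assuming a get_labelled_loader accessor (the method name is an assumption and may differ in your version):

# with loader_batch_size="full", a single iteration yields the whole split;
# get_labelled_loader is an assumed accessor name
X, y = next(iter(data_manager.get_labelled_loader()))
print(X.shape)  # one batch covering all 10000 labelled samples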