# Source code for pyrelational.model_managers.abstract_model_manager

import json
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, Optional, Type, TypeVar, Union

from torch.utils.data import DataLoader

# Type of the object produced by calling ``model_class`` (see ModelManager._init_model).
ModelType = TypeVar("ModelType")
# Type of the model held in ModelManager._current_model (None until a model is stored).
E = TypeVar("E")


class ModelManager(ABC, Generic[ModelType, E]):
    """
    Abstract class used to wrap models to interact with the Strategy.

    It handles model instantiation at each iteration, training, testing, and queries.
    """

    def __init__(
        self,
        model_class: Type[ModelType],
        model_config: Union[str, Dict[str, Any]],
        trainer_config: Union[str, Dict[str, Any]],
    ):
        """
        :param model_class: a model constructor (e.g. torch.nn.Linear)
        :param model_config: a dictionary containing the config required to instantiate a model
            from the model_class (e.g. {"in_features": 100, "out_features": 34, "bias": True}
            for a torch.nn.Linear constructor), or a path to a JSON file containing it
        :param trainer_config: a dictionary containing the config required to instantiate the
            trainer module/function, or a path to a JSON file containing it
        """
        super().__init__()
        self.model_class = model_class
        self.model_config = self._load_config(model_config)
        self._current_model: Optional[E] = None
        self.trainer_config = self._load_config(trainer_config)

    @staticmethod
    def _load_config(config: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
        """
        Return ``config`` as-is if it is not a string, otherwise parse it from a JSON file.

        :param config: a config dictionary, or a path to a JSON file holding one
        :return: the config dictionary (the same object when a dictionary was passed)
        """
        if isinstance(config, str):
            # A context manager closes the handle deterministically (the previous
            # ``json.load(open(...))`` leaked it); JSON text is UTF-8 per RFC 8259,
            # so don't depend on the platform's locale encoding.
            with open(config, "r", encoding="utf-8") as fh:
                return json.load(fh)
        return config

    def _init_model(self) -> ModelType:
        """
        Initialise model instance(s).

        :return: an instance of self.model_class based on self.model_config
        """
        return self.model_class(**self.model_config)

    def reset(self) -> None:
        """Reset stored _current_model."""
        self._current_model = None

    def is_trained(self) -> bool:
        """Check if model was trained (i.e. a current model is stored)."""
        return self._current_model is not None

    @abstractmethod
    def train(self, train_loader: DataLoader[Any], valid_loader: Optional[DataLoader[Any]] = None) -> None:
        """
        Run train routine.

        :param train_loader: pytorch dataloader for training set
        :param valid_loader: pytorch dataloader for validation set
        """
        pass

    @abstractmethod
    def test(self, loader: DataLoader[Any]) -> Dict[str, float]:
        """
        Run test routine.

        :param loader: pytorch dataloader for test set
        :return: performance metrics
        """
        pass

    def __call__(self, loader: DataLoader[Any]) -> Any:
        """
        Call method to output model predictions.

        The base implementation does nothing and returns None; subclasses are
        expected to override it.

        :param loader: pytorch dataloader
        :return: model predictions for each sample in dataloader
        """
        pass

    def __str__(self) -> str:
        return f"{self.__class__.__name__}({self.model_class.__name__})"