# Source code for plismbench.models.extractor

"""Core abstract base class for feature extractors."""

from abc import ABC, abstractmethod
from typing import Callable

import numpy as np
import torch


def _identity_transform(x):
    """Return ``x`` unchanged.

    Default transform of :class:`Extractor`. It is a module-level function
    rather than a lambda so that the default transform can be pickled.
    """
    return x


class Extractor(ABC):
    """A base class for :mod:`plismbench` extractors.

    Concrete extractors must implement :meth:`__call__`. The class only
    declares ``_feature_extractor``; it is not assigned here, so reading
    :attr:`feature_extractor` before it has been set raises
    ``AttributeError``.
    """

    # Declared only (no default value): must be assigned before being read.
    _feature_extractor: torch.nn.Module

    def __init__(self, *args, **kwargs):
        """Initialize the extractor with the identity transform.

        Parameters
        ----------
        *args, **kwargs
            Forwarded unchanged to the next ``__init__`` in the MRO.
        """
        super().__init__(*args, **kwargs)
        self._transform = _identity_transform

    @property
    def feature_extractor(self) -> torch.nn.Module:
        """
        Feature extractor module.

        Returns
        -------
        feature_extractor: torch.nn.Module
        """
        return self._feature_extractor

    @feature_extractor.setter
    def feature_extractor(self, feature_extractor_module: torch.nn.Module):
        """Set a new feature extractor module.

        Parameters
        ----------
        feature_extractor_module: torch.nn.Module
            The module to be used as feature extractor.
        """
        self._feature_extractor = feature_extractor_module

    @property
    def transform(self) -> Callable[[np.ndarray], torch.Tensor]:
        """
        Transform method to apply element wise. Inputs should be np.ndarray.

        This function is applied on ``np.ndarray`` and not ``PIL.Image.Image``
        as HuggingFace data is stored as numpy arrays for pickle checking
        purposes. If your model needs image resizing, then you will need to
        add a first ``transforms.ToPILImage()`` operation, then resizing and
        finally ``transforms.ToTensor()``. If your model is best working on
        images of shape 224x224, then no need for rescaling as PLISM tiles
        have 224x224 shapes.

        Default is identity (the input is returned unchanged, without any
        conversion).

        Returns
        -------
        transform: Callable[[np.ndarray], torch.Tensor]
        """
        return self._transform

    @transform.setter
    def transform(self, transform_function: Callable[[np.ndarray], torch.Tensor]):
        """Set a new transform function to the extractor.

        Parameters
        ----------
        transform_function: Callable[[np.ndarray], torch.Tensor]
            The transform function to be set for the extractor.
        """
        self._transform = transform_function

    @abstractmethod
    def __call__(self, images: torch.Tensor) -> np.ndarray:
        """
        Compute and return the MAP features.

        Parameters
        ----------
        images: torch.Tensor
            Input of size (N_TILES, 3, DIM_X, DIM_Y). N_TILES=1 for an image,
            usually DIM_X = DIM_Y = 224.

        Returns
        -------
        features : numpy.ndarray
            arrays of size (N_TILES, N_FEATURES) for an image

        Raises
        ------
        NotImplementedError
            Always, in this base implementation; subclasses must override.
        """
        raise NotImplementedError