Source code for plismbench.models.kaiko_ai

"""Models from Kaiko AI company."""

from __future__ import annotations

import numpy as np
import torch
from torchvision import transforms
from transformers import AutoModel

from plismbench.models.extractor import Extractor
from plismbench.models.utils import DEFAULT_DEVICE, prepare_module


class KaikoViTBase(Extractor):
    """Kaiko ViT-Base model available on PyTorch Hub (1-2).

    .. note::

        (1) kaiko.ai, Aben, N., de Jong, E. D., Gatopoulos, I., Känzig, N.,
            Karasikov, M., Lagré, A., Moser, R., van Doorn, J., & Tang, F. (2024).
            Towards large-scale training of pathology foundation models. arXiv.
            https://arxiv.org/abs/2404.15217
        (2) https://github.com/kaiko-ai/towards_large_pathology_fms

    Parameters
    ----------
    device: int | list[int] | None = DEFAULT_DEVICE
        Compute resources to use. If None, will use all available GPUs.
        If -1, extraction will run on CPU.
    mixed_precision: bool = False
        Whether to use mixed precision.
    """

    def __init__(
        self,
        device: int | list[int] | None = DEFAULT_DEVICE,
        mixed_precision: bool = False,
    ):
        super().__init__()
        self.output_dim = 768
        self.mixed_precision = mixed_precision

        feature_extractor = torch.hub.load(
            "kaiko-ai/towards_large_pathology_fms",
            "vitb8",
            trust_repo=True,
            verbose=True,
        )
        self.feature_extractor, self.device = prepare_module(
            feature_extractor,
            device,
            self.mixed_precision,
        )
        if self.device is None:
            self.feature_extractor = self.feature_extractor.module

    @property  # type: ignore
    def transform(self) -> transforms.Compose:
        """Transform method to apply element wise."""
        return transforms.Compose(
            [
                transforms.ToTensor(),  # swap axes and normalize
                transforms.Normalize(
                    mean=(0.5, 0.5, 0.5),
                    std=(0.5, 0.5, 0.5),
                ),
            ]
        )

    def __call__(self, images: torch.Tensor) -> np.ndarray:
        """Compute and return features.

        Parameters
        ----------
        images: torch.Tensor
            Input of size (n_tiles, n_channels, dim_x, dim_y).

        Returns
        -------
        np.ndarray
            Array of size (n_tiles, features_dim).
        """
        features = self.feature_extractor(images.to(self.device))
        return features.cpu().numpy()


class KaikoViTLarge(Extractor):
    """Kaiko ViT-Large model available on PyTorch Hub (1-2).

    .. note::

        (1) kaiko.ai, Aben, N., de Jong, E. D., Gatopoulos, I., Känzig, N.,
            Karasikov, M., Lagré, A., Moser, R., van Doorn, J., & Tang, F. (2024).
            Towards large-scale training of pathology foundation models. arXiv.
            https://arxiv.org/abs/2404.15217
        (2) https://github.com/kaiko-ai/towards_large_pathology_fms

    Parameters
    ----------
    device: int | list[int] | None = DEFAULT_DEVICE
        Compute resources to use. If None, will use all available GPUs.
        If -1, extraction will run on CPU.
    mixed_precision: bool = False
        Whether to use mixed precision.
    """

    def __init__(
        self,
        device: int | list[int] | None = DEFAULT_DEVICE,
        mixed_precision: bool = False,
    ):
        super().__init__()
        self.output_dim = 1024
        self.mixed_precision = mixed_precision

        feature_extractor = torch.hub.load(
            "kaiko-ai/towards_large_pathology_fms",
            "vitl14",
            trust_repo=True,
            verbose=True,
        )
        self.feature_extractor, self.device = prepare_module(
            feature_extractor,
            device,
            self.mixed_precision,
        )
        if self.device is None:
            self.feature_extractor = self.feature_extractor.module

    @property  # type: ignore
    def transform(self) -> transforms.Compose:
        """Transform method to apply element wise."""
        return transforms.Compose(
            [
                transforms.ToTensor(),  # swap axes and normalize
                transforms.Normalize(
                    mean=(0.5, 0.5, 0.5),
                    std=(0.5, 0.5, 0.5),
                ),
            ]
        )

    def __call__(self, images: torch.Tensor) -> np.ndarray:
        """Compute and return features.

        Parameters
        ----------
        images: torch.Tensor
            Input of size (n_tiles, n_channels, dim_x, dim_y).

        Returns
        -------
        np.ndarray
            Array of size (n_tiles, features_dim).
        """
        features = self.feature_extractor(images.to(self.device))
        return features.cpu().numpy()
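

# Illustrative usage sketch, not part of the original module: both Torch Hub
# extractors above expose the same interface, so a dummy tile can be pushed
# through ``transform`` and ``__call__`` as below. The 224x224 tile size, the
# dummy PIL image, and the helper name are assumptions made for illustration.
def _example_kaiko_vit_usage() -> None:
    """Run the Kaiko ViT extractors on a single dummy tile (hypothetical helper)."""
    from PIL import Image

    tile = Image.new("RGB", (224, 224))  # dummy RGB tile
    for extractor_cls in (KaikoViTBase, KaikoViTLarge):
        extractor = extractor_cls(device=-1)  # -1 runs extraction on CPU
        batch = extractor.transform(tile).unsqueeze(0)  # (1, 3, 224, 224)
        with torch.no_grad():  # inference only, no gradient tracking needed
            features = extractor(batch)
        # (1, 768) for KaikoViTBase, (1, 1024) for KaikoViTLarge
        assert features.shape == (1, extractor.output_dim)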


class Midnight12k(Extractor):
    """Midnight-12k model developed by Kaiko AI, available on Hugging Face (1).

    .. note::

        (1) https://huggingface.co/kaiko-ai/midnight

    Parameters
    ----------
    device: int | list[int] | None = DEFAULT_DEVICE
        Compute resources to use. If None, will use all available GPUs.
        If -1, extraction will run on CPU.
    mixed_precision: bool = False
        Whether to use mixed precision.
    """

    def __init__(
        self,
        device: int | list[int] | None = DEFAULT_DEVICE,
        mixed_precision: bool = False,
    ):
        super().__init__()
        self.output_dim = 3072
        self.mixed_precision = mixed_precision

        feature_extractor = AutoModel.from_pretrained("kaiko-ai/midnight")
        self.feature_extractor, self.device = prepare_module(
            feature_extractor,
            device,
            self.mixed_precision,
        )
        if self.device is None:
            self.feature_extractor = self.feature_extractor.module

    @property  # type: ignore
    def transform(self) -> transforms.Compose:
        """Transform method to apply element wise."""
        return transforms.Compose(
            [
                transforms.ToTensor(),  # swap axes and normalize
                transforms.Normalize(
                    mean=(0.5, 0.5, 0.5),
                    std=(0.5, 0.5, 0.5),
                ),
            ]
        )

    def __call__(self, images: torch.Tensor) -> np.ndarray:
        """Compute and return features.

        Parameters
        ----------
        images: torch.Tensor
            Input of size (n_tiles, n_channels, dim_x, dim_y).

        Returns
        -------
        np.ndarray
            Array of size (n_tiles, features_dim).
        """
        last_hidden_state = self.feature_extractor(
            images.to(self.device)
        ).last_hidden_state
        class_token = last_hidden_state[:, 0]
        patch_tokens = last_hidden_state[:, 1:]
        features = torch.cat([class_token, patch_tokens.mean(1)], dim=-1)
        return features.cpu().numpy()
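

# Illustrative usage sketch, not part of the original module: Midnight-12k's
# embedding concatenates the [CLS] token with the mean of the patch tokens, so
# ``output_dim`` (3072) is twice the transformer hidden size. The dummy tile,
# the 224x224 size, and the helper name are assumptions made for illustration.
def _example_midnight12k_usage() -> None:
    """Run Midnight12k on a single dummy tile (hypothetical helper)."""
    from PIL import Image

    extractor = Midnight12k(device=-1)  # -1 runs extraction on CPU
    tile = Image.new("RGB", (224, 224))  # dummy RGB tile
    batch = extractor.transform(tile).unsqueeze(0)  # (1, 3, 224, 224)
    with torch.no_grad():  # inference only, no gradient tracking needed
        features = extractor(batch)  # np.ndarray of shape (1, 3072)
    assert features.shape == (1, extractor.output_dim)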