"""Models from Kaiko AI company."""
from __future__ import annotations
import numpy as np
import torch
from torchvision import transforms
from transformers import AutoModel
from plismbench.models.extractor import Extractor
from plismbench.models.utils import DEFAULT_DEVICE, prepare_module
[docs]
class KaikoViTBase(Extractor):
"""Kaiko ViT-Base model available on Pytorch Hub (1-2).
.. note::
(1) kaiko. ai, Aben, N., de Jong, E. D., Gatopoulos, I., Känzig, N., Karasikov, M., Lagré, A., Moser, R., van Doorn, J., & Tang, F. (2024). Towards large-scale training of pathology foundation models. arXiv. https://arxiv.org/abs/2404.15217
(2) https://github.com/kaiko-ai/towards_large_pathology_fms
Parameters
----------
device: int | list[int] | None = DEFAULT_DEVICE,
Compute resources to use.
If None, will use all available GPUs.
If -1, extraction will run on CPU.
mixed_precision: bool = True
Whether to use mixed_precision.
"""
def __init__(
self,
device: int | list[int] | None = DEFAULT_DEVICE,
mixed_precision: bool = False,
):
super().__init__()
self.output_dim = 768
self.mixed_precision = mixed_precision
feature_extractor = torch.hub.load(
"kaiko-ai/towards_large_pathology_fms",
"vitb8",
trust_repo=True,
verbose=True,
)
self.feature_extractor, self.device = prepare_module(
feature_extractor,
device,
self.mixed_precision,
)
if self.device is None:
self.feature_extractor = self.feature_extractor.module
@property # type: ignore
def transform(self) -> transforms.Compose:
"""Transform method to apply element wise."""
return transforms.Compose(
[
transforms.ToTensor(), # swap axes and normalize
transforms.Normalize(
mean=(0.5, 0.5, 0.5),
std=(0.5, 0.5, 0.5),
),
]
)
def __call__(self, images: torch.Tensor) -> np.ndarray:
"""Compute and return features.
Parameters
----------
images: torch.Tensor
Input of size (n_tiles, n_channels, dim_x, dim_y).
Returns
-------
torch.Tensor: Tensor of size (n_tiles, features_dim).
"""
features = self.feature_extractor(images.to(self.device))
return features.cpu().numpy()
[docs]
class KaikoViTLarge(Extractor):
"""Kaiko ViT-Large model available on Pytorch Hub (1-2).
.. note::
(1) kaiko. ai, Aben, N., de Jong, E. D., Gatopoulos, I., Känzig, N., Karasikov, M., Lagré, A., Moser, R., van Doorn, J., & Tang, F. (2024). Towards large-scale training of pathology foundation models. arXiv. https://arxiv.org/abs/2404.15217
(2) https://github.com/kaiko-ai/towards_large_pathology_fms
Parameters
----------
device: int | list[int] | None = DEFAULT_DEVICE,
Compute resources to use.
If None, will use all available GPUs.
If -1, extraction will run on CPU.
mixed_precision: bool = True
Whether to use mixed_precision.
"""
def __init__(
self,
device: int | list[int] | None = DEFAULT_DEVICE,
mixed_precision: bool = False,
):
super().__init__()
self.output_dim = 1024
self.mixed_precision = mixed_precision
feature_extractor = torch.hub.load(
"kaiko-ai/towards_large_pathology_fms",
"vitl14",
trust_repo=True,
verbose=True,
)
self.feature_extractor, self.device = prepare_module(
feature_extractor,
device,
self.mixed_precision,
)
if self.device is None:
self.feature_extractor = self.feature_extractor.module
@property # type: ignore
def transform(self) -> transforms.Compose:
"""Transform method to apply element wise."""
return transforms.Compose(
[
transforms.ToTensor(), # swap axes and normalize
transforms.Normalize(
mean=(0.5, 0.5, 0.5),
std=(0.5, 0.5, 0.5),
),
]
)
def __call__(self, images: torch.Tensor) -> np.ndarray:
"""Compute and return features.
Parameters
----------
images: torch.Tensor
Input of size (n_tiles, n_channels, dim_x, dim_y).
Returns
-------
torch.Tensor: Tensor of size (n_tiles, features_dim).
"""
features = self.feature_extractor(images.to(self.device))
return features.cpu().numpy()
[docs]
class Midnight12k(Extractor):
"""Midnight-12k model developped by Kaiko AI available on Hugging-Face (1).
.. note::
(1) https://huggingface.co/kaiko-ai/midnight
Parameters
----------
device: int | list[int] | None = DEFAULT_DEVICE,
Compute resources to use.
If None, will use all available GPUs.
If -1, extraction will run on CPU.
mixed_precision: bool = True
Whether to use mixed_precision.
"""
def __init__(
self,
device: int | list[int] | None = DEFAULT_DEVICE,
mixed_precision: bool = False,
):
super().__init__()
self.output_dim = 3072
self.mixed_precision = mixed_precision
feature_extractor = AutoModel.from_pretrained("kaiko-ai/midnight")
self.feature_extractor, self.device = prepare_module(
feature_extractor,
device,
self.mixed_precision,
)
if self.device is None:
self.feature_extractor = self.feature_extractor.module
@property # type: ignore
def transform(self) -> transforms.Compose:
"""Transform method to apply element wise."""
return transforms.Compose(
[
transforms.ToTensor(), # swap axes and normalize
transforms.Normalize(
mean=(0.5, 0.5, 0.5),
std=(0.5, 0.5, 0.5),
),
]
)
def __call__(self, images: torch.Tensor) -> np.ndarray:
"""Compute and return features.
Parameters
----------
images: torch.Tensor
Input of size (n_tiles, n_channels, dim_x, dim_y).
Returns
-------
torch.Tensor: Tensor of size (n_tiles, features_dim).
"""
last_hidden_state = self.feature_extractor(images.to(self.device))
class_token = last_hidden_state[:, 0]
patch_tokens = last_hidden_state[:, 1:]
features = torch.cat([class_token, patch_tokens.mean(1)], dim=-1)
return features.cpu().numpy()