Source code for plismbench.models.lunit
"""Models from Lunit company."""
from __future__ import annotations
import numpy as np
import torch
from timm.models.vision_transformer import VisionTransformer
from torchvision import transforms
from plismbench.models.extractor import Extractor
from plismbench.models.utils import DEFAULT_DEVICE, prepare_module
from plismbench.utils.core import download_state_dict


class LunitViTS8(Extractor):
"""ViT-S/8 from Lunit available at (1).
.. note::
(1) https://github.com/lunit-io/benchmark-ssl-pathology/releases/tag/pretrained-weights
Parameters
----------
device: int | list[int] | None = DEFAULT_DEVICE,
Compute resources to use.
If None, will use all available GPUs.
If -1, extraction will run on CPU.
mixed_precision: bool = True
Whether to use mixed_precision.
"""

    def __init__(
self,
device: int | list[int] | None = DEFAULT_DEVICE,
mixed_precision: bool = False,
):
super().__init__()
self.output_dim = 384
self.mixed_precision = mixed_precision
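        # ViT-S/8 backbone matching the Lunit DINO pretraining configuration:
        # 224x224 inputs, 8x8 patches, 384-dim embeddings, 6 attention heads
        # and no classification head (num_classes=0).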
feature_extractor = VisionTransformer(
img_size=224,
patch_size=8,
embed_dim=384,
num_heads=6,
num_classes=0,
)
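        # Fetch the DINO ViT-S/8 checkpoint released by Lunit and load it into
        # the timm backbone; strict=False tolerates keys that do not match the
        # backbone's state dict.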
state_dict_path = download_state_dict(
url="https://github.com/lunit-io/benchmark-ssl-pathology/releases/download/pretrained-weights/dino_vit_small_patch8_ep200.torch",
name="lunit_vit_s8.pth",
)
state_dict = torch.load(state_dict_path, map_location="cpu")
feature_extractor.load_state_dict(state_dict, strict=False)
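        # prepare_module dispatches the backbone to the requested device(s)
        # and applies the mixed-precision setting; it returns the prepared
        # module and the resolved device.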
self.feature_extractor, self.device = prepare_module(
feature_extractor,
device,
self.mixed_precision,
)
if self.device is None:
self.feature_extractor = self.feature_extractor.module

    @property  # type: ignore
    def transform(self) -> transforms.Compose:
        """Transform method to apply element-wise."""
        return transforms.Compose(
            [
                transforms.ToTensor(),
                # Channel-wise normalization statistics tailored to
                # H&E-stained pathology tiles (rather than ImageNet
                # statistics).
                transforms.Normalize(
                    mean=(0.70322989, 0.53606487, 0.66096631),
                    std=(0.21716536, 0.26081574, 0.20723464),
                ),
            ]
        )

    def __call__(self, images: torch.Tensor) -> np.ndarray:
        """Compute and return features.

        Parameters
        ----------
        images: torch.Tensor
            Input of size (n_tiles, n_channels, dim_x, dim_y).

        Returns
        -------
        np.ndarray: Array of size (n_tiles, features_dim).
        """
features = self.feature_extractor(images.to(self.device))
return features.cpu().numpy()
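

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes Pillow is available (it ships with torchvision), that DEFAULT_DEVICE
# resolves to a usable device on your machine, and that instantiation downloads
# (or reuses a cached copy of) the pretrained weights. The forward pass is
# wrapped in `torch.no_grad()` since gradients are not needed for extraction.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image

    extractor = LunitViTS8()  # uses DEFAULT_DEVICE and full precision
    tile = Image.new("RGB", (224, 224))  # placeholder 224x224 RGB tile
    batch = extractor.transform(tile).unsqueeze(0)  # shape: (1, 3, 224, 224)
    with torch.no_grad():
        features = extractor(batch)
    print(features.shape)  # expected: (1, 384)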