Source code for plismbench.models.hkust

"""Models from Hong Kong University of Science and Technology."""

from __future__ import annotations

import re

import numpy as np
import timm
import torch
from torchvision import transforms

from plismbench.models.extractor import Extractor
from plismbench.models.utils import DEFAULT_DEVICE, prepare_module
from plismbench.utils.core import download_state_dict


def _convert_state_dict(state_dict: dict) -> dict:
    """Rename state dict keys to match timm's format."""
    state_dict = {
        re.sub(r"blocks\.\d+\.(\d+)", r"blocks.\1", key.replace("backbone.", "")): value
        for key, value in state_dict.items()
    }
    remove_keys = ["mask_token"] + [
        key for key in state_dict.keys() if "dino_head" in key
    ]
    for key in remove_keys:
        state_dict.pop(key)
    return state_dict
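

# Illustrative sketch of the conversion above (the keys are hypothetical, not
# taken from the actual GPFM checkpoint): the "backbone." prefix is stripped,
# nested block indices such as "blocks.0.3" are collapsed to "blocks.3", and
# the mask token plus DINO head weights are dropped so the dict matches timm's
# ViT layout.
#
#   >>> _convert_state_dict(
#   ...     {
#   ...         "backbone.blocks.0.3.mlp.fc1.weight": torch.zeros(1),
#   ...         "backbone.mask_token": torch.zeros(1),
#   ...         "dino_head.last_layer.weight": torch.zeros(1),
#   ...     }
#   ... )
#   {'blocks.3.mlp.fc1.weight': tensor([0.])}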


class GPFM(Extractor):
    """GPFM model developed by HKUST (1).

    .. note::
        (1) Ma, J., Guo, Z., Zhou, F., Wang, Y., Xu, Y., et al. (2024).
        Towards a generalizable pathology foundation model via unified
        knowledge distillation (arXiv No. 2407.18449). arXiv.
        https://arxiv.org/abs/2407.18449

    Parameters
    ----------
    device: int | list[int] | None = DEFAULT_DEVICE
        Compute resources to use. If None, will use all available GPUs.
        If -1, extraction will run on CPU.
    mixed_precision: bool = False
        Whether to use mixed precision.
    """

    def __init__(
        self,
        device: int | list[int] | None = DEFAULT_DEVICE,
        mixed_precision: bool = False,
    ):
        super().__init__()
        self.output_dim = 1024
        self.mixed_precision = mixed_precision

        _state_dict_path = download_state_dict(
            url="https://github.com/birkhoffkiki/GPFM/releases/download/ckpt/GPFM.pth",
            name="GPFM.pth",
        )
        _state_dict = torch.load(_state_dict_path, map_location="cpu")
        state_dict = _convert_state_dict(_state_dict["teacher"])

        feature_extractor = timm.create_model(
            model_name="vit_large_patch14_dinov2",
            pretrained=True,
            pretrained_cfg={
                "state_dict": state_dict,
                "num_classes": 0,
            },
            img_size=224,
            patch_size=14,
            init_values=1e-5,
            qkv_bias=True,
            dynamic_img_size=True,
        )

        self.feature_extractor, self.device = prepare_module(
            feature_extractor,
            device,
            self.mixed_precision,
        )
        if self.device is None:
            self.feature_extractor = self.feature_extractor.module

    @property  # type: ignore
    def transform(self) -> transforms.Compose:
        """Transform method to apply element-wise."""
        return transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
                ),
            ]
        )

    def __call__(self, images: torch.Tensor) -> np.ndarray:
        """Compute and return features.

        Parameters
        ----------
        images: torch.Tensor
            Input of size (n_tiles, n_channels, dim_x, dim_y).

        Returns
        -------
        np.ndarray
            Array of size (n_tiles, features_dim).
        """
        features = self.feature_extractor(images.to(self.device))
        return features.cpu().numpy()
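

# Minimal usage sketch (not part of plismbench itself): instantiate the
# extractor and embed a dummy batch of 224x224 RGB tiles. The device index,
# batch size, and placeholder PIL images below are illustrative assumptions.
if __name__ == "__main__":
    from PIL import Image

    extractor = GPFM(device=0, mixed_precision=False)  # use device=-1 for CPU
    tiles = [Image.new("RGB", (224, 224)) for _ in range(4)]  # dummy tiles
    batch = torch.stack([extractor.transform(tile) for tile in tiles])
    features = extractor(batch)  # numpy array of shape (4, 1024)
    print(features.shape)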