"""Models from Mahmood Lab."""
from __future__ import annotations
from pathlib import Path
from typing import Any
import numpy as np
import timm
import torch
from conch.open_clip_custom import create_model_from_pretrained
from huggingface_hub import snapshot_download
from loguru import logger
from torchvision import transforms
from transformers import AutoModel
from plismbench.models.extractor import Extractor
from plismbench.models.utils import DEFAULT_DEVICE, prepare_device, prepare_module
class UNI(Extractor):
"""UNI model developped by Mahmood Lab available on Hugging-Face (1).
.. note::
(1) https://huggingface.co/MahmoodLab/UNI
Parameters
----------
device: int | list[int] | None = DEFAULT_DEVICE,
Compute resources to use.
If None, will use all available GPUs.
If -1, extraction will run on CPU.
mixed_precision: bool = True
Whether to use mixed_precision.
"""
def __init__(
self,
device: int | list[int] | None = DEFAULT_DEVICE,
mixed_precision: bool = False,
):
super().__init__()
self.output_dim = 1024
self.mixed_precision = mixed_precision
timm_kwargs: dict[str, Any] = {
"init_values": 1e-5,
"dynamic_img_size": True,
}
feature_extractor = timm.create_model(
"hf-hub:MahmoodLab/uni", pretrained=True, **timm_kwargs
)
self.feature_extractor, self.device = prepare_module(
feature_extractor,
device,
self.mixed_precision,
)
if self.device is None:
self.feature_extractor = self.feature_extractor.module
@property # type: ignore
def transform(self) -> transforms.Compose:
"""Transform method to apply element wise."""
return transforms.Compose(
[
transforms.ToTensor(), # swap axes and normalize
transforms.Normalize(
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
),
]
)
def __call__(self, images: torch.Tensor) -> np.ndarray:
"""Compute and return features.
Parameters
----------
images: torch.Tensor
Input of size (n_tiles, n_channels, dim_x, dim_y).
Returns
-------
torch.Tensor: Tensor of size (n_tiles, features_dim).
"""
features = self.feature_extractor(images.to(self.device))
return features.cpu().numpy()
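

# A minimal usage sketch (not part of the library API), assuming equally sized PIL tiles
# on disk and a single-GPU machine: how one might run the ``UNI`` extractor end to end
# (``UNI2h`` is used the same way). ``tile_paths`` is a hypothetical argument and is not
# defined elsewhere in plismbench.
def _example_uni_usage(tile_paths: list[Path]) -> np.ndarray:
    from PIL import Image

    extractor = UNI(device=0)  # -1 would run on CPU, None on all available GPUs
    # ``transform`` converts each PIL tile into a normalized (3, H, W) tensor.
    batch = torch.stack(
        [extractor.transform(Image.open(path).convert("RGB")) for path in tile_paths]
    )
    # ``__call__`` moves the batch to the extractor's device and returns numpy features.
    return extractor(batch)  # shape (n_tiles, 1024)
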
class UNI2h(Extractor):
"""UNI2-h model developped by Mahmood Lab available on Hugging-Face (1).
.. note::
(1) https://huggingface.co/MahmoodLab/UNI2-h
Parameters
----------
device: int | list[int] | None = DEFAULT_DEVICE,
Compute resources to use.
If None, will use all available GPUs.
If -1, extraction will run on CPU.
mixed_precision: bool = True
Whether to use mixed_precision.
"""
def __init__(
self,
device: int | list[int] | None = DEFAULT_DEVICE,
mixed_precision: bool = False,
):
super().__init__()
self.output_dim = 1536
self.mixed_precision = mixed_precision
timm_kwargs: dict[str, Any] = {
"img_size": 224,
"patch_size": 14,
"depth": 24,
"num_heads": 24,
"init_values": 1e-5,
"embed_dim": 1536,
"mlp_ratio": 2.66667 * 2,
"num_classes": 0,
"no_embed_class": True,
"mlp_layer": timm.layers.SwiGLUPacked,
"act_layer": torch.nn.SiLU,
"reg_tokens": 8,
"dynamic_img_size": True,
}
feature_extractor = timm.create_model(
"hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs
)
self.feature_extractor, self.device = prepare_module(
feature_extractor,
device,
self.mixed_precision,
)
if self.device is None:
self.feature_extractor = self.feature_extractor.module
@property # type: ignore
def transform(self) -> transforms.Compose:
"""Transform method to apply element wise."""
return transforms.Compose(
[
transforms.ToTensor(), # swap axes and normalize
transforms.Normalize(
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
),
]
)
def __call__(self, images: torch.Tensor) -> np.ndarray:
"""Compute and return features.
Parameters
----------
images: torch.Tensor
Input of size (n_tiles, n_channels, dim_x, dim_y).
Returns
-------
torch.Tensor: Tensor of size (n_tiles, features_dim).
"""
features = self.feature_extractor(images.to(self.device))
return features.cpu().numpy()
class CONCH(Extractor):
"""CONCH model developped by Mahmood Lab available on Hugging-Face (1).
.. note::
(1) https://huggingface.co/MahmoodLab/CONCH
Parameters
----------
device: int | list[int] | None = DEFAULT_DEVICE,
Compute resources to use.
If None, will use all available GPUs.
If -1, extraction will run on CPU.
mixed_precision: bool = True
Whether to use mixed_precision.
"""
def __init__(
self,
device: int | list[int] | None = DEFAULT_DEVICE,
mixed_precision: bool = False,
):
super().__init__()
self.output_dim = 512
self.mixed_precision = mixed_precision
checkpoint_dir = snapshot_download(repo_id="MahmoodLab/CONCH")
checkpoint_path = Path(checkpoint_dir) / "pytorch_model.bin"
feature_extractor, self.processor = create_model_from_pretrained(
"conch_ViT-B-16",
force_image_size=224,
checkpoint_path=str(checkpoint_path),
device=prepare_device(device),
)
self.feature_extractor, self.device = prepare_module(
feature_extractor,
device,
self.mixed_precision,
)
if self.device is None:
self.feature_extractor = self.feature_extractor.module
def process(self, image) -> torch.Tensor:
"""Process input images."""
conch_input = self.processor(image)
return conch_input
@property # type: ignore
def transform(self) -> transforms.Lambda:
"""Transform method to apply element wise."""
return transforms.Lambda(self.process)
def __call__(self, images: torch.Tensor) -> np.ndarray:
"""Compute and return features.
Parameters
----------
images: torch.Tensor
Input of size (n_tiles, n_channels, dim_x, dim_y).
Returns
-------
torch.Tensor: Tensor of size (n_tiles, features_dim).
"""
        # ``prepare_module`` may wrap the model for multi-GPU extraction; on CPU the
        # wrapper was already unwrapped in ``__init__``, so fall back to the bare module.
        module = getattr(self.feature_extractor, "module", self.feature_extractor)
        features = module.encode_image(  # type: ignore
            images.to(self.device), proj_contrast=False, normalize=False
        )
return features.cpu().numpy()
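

# A minimal usage sketch (not part of the library API): unlike UNI, ``CONCH`` preprocesses
# tiles with the processor shipped alongside the CONCH weights (exposed via ``transform``),
# which already resizes them to 224x224. ``tile_paths`` is a hypothetical argument and is
# not defined elsewhere in plismbench.
def _example_conch_usage(tile_paths: list[Path]) -> np.ndarray:
    from PIL import Image

    extractor = CONCH(device=0)
    batch = torch.stack(
        [extractor.transform(Image.open(path).convert("RGB")) for path in tile_paths]
    )
    return extractor(batch)  # shape (n_tiles, 512)
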
class CONCHv15(Extractor):
"""Conchv15 model available from TITAN on Hugging-Face (1).
.. note::
(1) https://huggingface.co/MahmoodLab/conchv1_5
"""
def __init__(
self,
device: int | list[int] | None = DEFAULT_DEVICE,
mixed_precision: bool = False,
):
super().__init__()
self.output_dim = 768
self.mixed_precision = mixed_precision
titan = AutoModel.from_pretrained("MahmoodLab/TITAN", trust_remote_code=True)
feature_extractor, _ = titan.return_conch()
self.feature_extractor, self.device = prepare_module(
feature_extractor,
device,
self.mixed_precision,
)
if self.device is None:
self.feature_extractor = self.feature_extractor.module
logger.info("This model is best performing on 448x448 images.")
@property # type: ignore
    def transform(self) -> transforms.Compose:
"""Transform method to apply element wise."""
return transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(
mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
),
]
)
def __call__(self, images: torch.Tensor) -> np.ndarray:
"""Compute and return features.
Args:
images (torch.Tensor): Input of size (n_tiles, n_channels, dim_x, dim_y).
Returns
-------
torch.Tensor: Tensor of size (n_tiles, features_dim).
"""
features = self.feature_extractor(images.to(self.device))
return features.cpu().numpy()
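

# A minimal usage sketch (not part of the library API): the log message in ``CONCHv15``
# notes that the model performs best on 448x448 tiles, so this hypothetical helper resizes
# inputs before extraction. ``tile_paths`` is a hypothetical argument and is not defined
# elsewhere in plismbench.
def _example_conchv15_usage(tile_paths: list[Path]) -> np.ndarray:
    from PIL import Image

    extractor = CONCHv15(device=0)
    resize = transforms.Resize((448, 448))
    batch = torch.stack(
        [extractor.transform(resize(Image.open(path).convert("RGB"))) for path in tile_paths]
    )
    return extractor(batch)  # shape (n_tiles, 768)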