Source code for plismbench.engine.extract.extract_from_h5
"""Download PLISM tiles dataset as h5 files and extract features for a given model."""
from __future__ import annotations
from collections.abc import Callable
from functools import partial
from pathlib import Path
import h5py
import numpy as np
import torch
from loguru import logger
from torch.utils.data import DataLoader
from tqdm import tqdm
from plismbench.engine.extract.utils import (
NUM_TILES_PER_SLIDE,
process_imgs,
save_features,
)
from plismbench.models import FeatureExtractorsEnum
[docs]
class H5Dataset(torch.utils.data.Dataset):
"""Dataset wrapper iterating over a .h5 file content.
Parameters
----------
file_path: Path
Path to the .h5 file.
"""
def __init__(self, file_path: Path):
super().__init__()
self.file_path = file_path
self.data = h5py.File(self.file_path, "r", libver="latest", swmr=True)
self.keys = list(self.data.keys())
def __len__(self):
"""Get length of dataset."""
length = len(self.keys)
assert length == NUM_TILES_PER_SLIDE, (
f"H5 file for slide {self.file_path.stem} does not contain {NUM_TILES_PER_SLIDE} tiles!"
)
return length
def __getitem__(self, idx):
"""Get next item (``tile_id``, ``tile_array``)."""
tile_id = self.keys[idx]
tile_array = self.data[tile_id][:]
return tile_id, tile_array
[docs]
def collate(
batch: list[tuple[str, torch.Tensor]],
transform: Callable[[np.ndarray], torch.Tensor],
) -> tuple[list[str], torch.Tensor]:
"""Return tile ids and transformed images.
Parameters
----------
batch: list[dict[str, Any]]
List of length ``batch_size`` made of tuples.
Each tuple represents a tile_id and the corresponding image.
The image is a torch.float32 tensor (between 0 and 1).
transform: Callable[[np.ndarray], torch.Tensor]
Transform function taking ``np.ndarray`` image as inputs.
Returns
-------
output: tuple[list[str], torch.Tensor]
A tuple made of tiles ids and transformed input images.
"""
tile_ids: list[str] = [b[0] for b in batch]
raw_imgs: list[np.ndarray] = [b[1] for b in batch] # type: ignore
imgs = torch.stack([transform(img) for img in raw_imgs])
output = (tile_ids, imgs)
return output
[docs]
def get_dataloader(
slide_h5_path: Path,
transform: Callable[[np.ndarray], torch.Tensor],
batch_size: int = 32,
workers: int = 8,
) -> DataLoader:
"""Get PLISM tiles dataset dataloader transformed with ``transform`` function.
Parameters
----------
slide_h5_path: Path
Path to the .h5 containing tiles for a given slide.
transform: Callable[[np.ndarray], torch.Tensor]
Transform function taking ``np.ndarray`` image as inputs.
batch_size: int = 32
Batch size for features extraction.
workers: int = 8
Number of workers to load images.
Returns
-------
dataloader: DataLoader
DataLoader returning (tile_ids, images).
See ``collate`` function for details.
"""
dataset = H5Dataset(file_path=slide_h5_path)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
collate_fn=partial(collate, transform=transform),
num_workers=workers,
pin_memory=True,
shuffle=False,
)
return dataloader
[docs]
def run_extract_h5(
feature_extractor_name: str,
batch_size: int,
device: int,
export_dir: Path,
download_dir: Path,
overwrite: bool = False,
workers: int = 8,
) -> None:
"""Run features extraction."""
if overwrite:
logger.warning("You are about to overwrite existing features.")
logger.info(f"Download directory set to {str(download_dir)}.")
logger.info(f"Export directory set to {str(export_dir)}.")
# Create export directory if it doesn't exist
export_dir.mkdir(exist_ok=True, parents=True)
# Initialize the feature extractor
feature_extractor = FeatureExtractorsEnum[feature_extractor_name.upper()].init(
device=device
)
image_transform = feature_extractor.transform
device = feature_extractor.device
slide_h5_paths = list(download_dir.glob("*_to_GMH_S60.tif.h5"))
# assert (n_slides := len(slide_h5_paths)) == NUM_SLIDES, (
# f"Download uncomplete: found {n_slides}/{NUM_SLIDES}"
# )
for slide_h5_path in tqdm(slide_h5_paths):
# Get slide id
slide_id = slide_h5_path.stem
# Get output path for features
slide_features_export_dir = Path(export_dir / slide_id)
slide_features_export_path = slide_features_export_dir / "features.npy"
if slide_features_export_path.exists():
if overwrite:
logger.info(
f"Features for slide {slide_id} already extracted. Overwriting..."
)
else:
logger.info(
f"Features for slide {slide_id} already extracted. Skipping..."
)
continue
slide_features_export_dir.mkdir(exist_ok=True, parents=True)
# Instanciate the dataloader
dataloader = get_dataloader(
slide_h5_path=slide_h5_path,
transform=image_transform,
batch_size=batch_size,
workers=workers,
)
# Iterate over the full dataset and store features each time 16,278 input images have been processed
slide_features: list[np.ndarray] = []
for tile_ids, tile_images in tqdm(
dataloader, total=len(dataloader), leave=False
):
batch_stack = process_imgs(tile_images, tile_ids, model=feature_extractor)
slide_features.append(batch_stack)
save_features(
slide_features,
slide_id=slide_id,
export_path=slide_features_export_path,
)
logger.success(f"Successfully saved features for slide: {slide_id}")