Source code for plismbench.engine.extract.extract_from_png
"""Stream PLISM tiles dataset and extract features on-the-fly for a given model."""
from __future__ import annotations
from collections.abc import Callable
from functools import partial
from math import ceil
from pathlib import Path
import datasets
import numpy as np
import torch
from loguru import logger
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from plismbench.engine.extract.utils import (
NUM_SLIDES,
NUM_TILES_PER_SLIDE,
process_imgs,
save_features,
)
from plismbench.models import FeatureExtractorsEnum
[docs]
def collate(
batch: list[dict[str, str | Image.Image]],
transform: Callable[[np.ndarray], torch.Tensor],
) -> tuple[list[str], list[str], torch.Tensor]:
"""Return slide ids, tile ids and transformed images.
Parameters
----------
batch: list[dict[str, str | Image.Image]],
List of length ``batch_size`` made of dictionnaries.
Each dictionnary is a single input with keys: 'slide_id',
'tile_id' and 'png'. The image is a ``PIL.Image.Image``
with type unit8 (0-255)
transform: Callable[[np.ndarray], torch.Tensor]
Transform function taking ``np.ndarray`` image as inputs.
Prior to calling this transform function, conversion from a
``PIL.Image.Image`` to an array is performed.
Returns
-------
output: tuple[list[str], list[str], torch.Tensor]
A tuple made of slides ids, tiles ids and transformed input images.
"""
slide_ids: list[str] = [b["slide_id"] for b in batch] # type: ignore
tile_ids: list[str] = [b["tile_id"] for b in batch] # type: ignore
imgs = torch.stack([transform(np.array(b["png"])) for b in batch])
output = (slide_ids, tile_ids, imgs)
return output
[docs]
def run_extract_streaming(
feature_extractor_name: str,
batch_size: int,
device: int,
export_dir: Path,
overwrite: bool = False,
) -> None:
"""Run features extraction with streaming."""
if overwrite:
logger.warning("You are about to overwrite existing features.")
logger.info(f"Export directory set to {str(export_dir)}.")
# Create export directory if it doesn't exist
export_dir.mkdir(exist_ok=True, parents=True)
# Initialize the feature extractor
feature_extractor = FeatureExtractorsEnum[feature_extractor_name.upper()].init(
device=device
)
image_transform = feature_extractor.transform
device = feature_extractor.device
# Create the dataset and dataloader without actually loading the files to disk (`streaming=True`)
# The dataset is sorted by slide_id, meaning that the first 16278 indexes belong to the same first slide,
# then 16278:32556 to the second slide, etc.
dataset = datasets.load_dataset(
"owkin/plism-dataset-tiles", split="train", streaming=True
)
collate_fn = partial(collate, transform=image_transform)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
collate_fn=collate_fn,
num_workers=0,
pin_memory=True,
shuffle=False,
)
# Iterate over the full dataset and store features each time 16278 input images have been processed
slide_features = []
current_num_tiles = 0
existing_slide_already_checked = False
for slide_ids, tile_ids, imgs in tqdm(
dataloader,
total=ceil(NUM_SLIDES * NUM_TILES_PER_SLIDE / batch_size),
desc="Extracting features",
):
reference_slide_id = slide_ids[0]
# Get output path for features
slide_features_export_dir = Path(export_dir / reference_slide_id)
slide_features_export_path = slide_features_export_dir / "features.npy"
if slide_features_export_path.exists():
if not existing_slide_already_checked:
if overwrite:
logger.info(
f"Features for slide {reference_slide_id} already extracted. Overwriting..."
)
existing_slide_already_checked = True
else:
logger.info(
f"Features for slide {reference_slide_id} already extracted. Skipping..."
)
existing_slide_already_checked = True
continue
slide_features_export_dir.mkdir(exist_ok=True, parents=True)
# If we're on the same slide, we just add the batch features to the running list
if all(slide_id == reference_slide_id for slide_id in slide_ids):
batch_stack = process_imgs(imgs, tile_ids, model=feature_extractor)
slide_features.append(batch_stack)
# For the very last slide, the last batch may be of size < `batch_size`
current_num_tiles += batch_stack.shape[0]
# If the current batch contains exactly the last `batch_size` tile features for the slide,
# export the slide features and reset `slide_features` and `current_num_tiles`
if current_num_tiles == NUM_TILES_PER_SLIDE:
save_features(
slide_features,
slide_id=reference_slide_id,
export_path=slide_features_export_path,
)
logger.success(
f"Successfully saved features for slide: {reference_slide_id}"
)
slide_features = []
current_num_tiles = 0
existing_slide_already_checked = False
# The current batch contains tiles from slide N (`reference_slide_id`) and slide N+1
else:
# We retrieve the maximum index at which all tiles in the batch comes from slide N
mask = np.array(slide_ids) != reference_slide_id
idx = mask.argmax()
# And only process the later, then export the slides features
batch_stack = process_imgs(
imgs[:idx], tile_ids[:idx], model=feature_extractor
)
slide_features.append(batch_stack)
save_features(
slide_features,
slide_id=reference_slide_id,
export_path=slide_features_export_path,
)
logger.success(
f"Successfully saved features for slide: {reference_slide_id}"
)
# We initialize `slide_features` and `current_num_tiles` with respectively
# the tile features from slide N+1
slide_features = [
process_imgs(imgs[idx:], tile_ids[idx:], model=feature_extractor)
]
current_num_tiles = batch_size - idx
existing_slide_already_checked = False