# Source code for plismbench.engine.extract.utils

"""Utility functionalities for the extraction pipeline."""

from pathlib import Path

import numpy as np
import pandas as pd
import torch


# Do not change these values: the PLISM dataset contains exactly 91 slides
# of 16,278 tiles each, and `save_features` asserts the per-slide tile
# count against NUM_TILES_PER_SLIDE before writing features to disk.
NUM_SLIDES: int = 91
NUM_TILES_PER_SLIDE: int = 16_278


def sort_coords(slide_features: np.ndarray) -> np.ndarray:
    """Sort slide features by their (x, y) tile coordinates.

    Parameters
    ----------
    slide_features: np.ndarray
        Feature matrix whose columns 1 and 2 hold the x and y tile
        coordinates; all other columns are carried along unchanged.

    Returns
    -------
    np.ndarray
        The same rows, reordered by ascending x, then ascending y.
    """
    coords = pd.DataFrame(slide_features[:, 1:3], columns=["x", "y"])
    sorted_index = coords.sort_values(by=["x", "y"]).index.to_numpy()
    return slide_features[sorted_index]
def save_features(
    slide_features: list[np.ndarray],
    slide_id: str,
    export_path: Path,
) -> None:
    """Concatenate, sort and save a slide's features to disk.

    Batched feature arrays are concatenated into a single float32 matrix,
    sorted by tile coordinates (see ``sort_coords``) and written with
    ``np.save``.

    Parameters
    ----------
    slide_features: list[np.ndarray]
        Per-batch feature arrays for the current slide.
    slide_id: str
        Current slide id (used in the error message only).
    export_path: Path
        Export path for the slide features (``.npy`` file).

    Raises
    ------
    AssertionError
        If the slide does not contain exactly ``NUM_TILES_PER_SLIDE`` tiles.
    """
    _output_slide_features = np.concatenate(slide_features, axis=0).astype(np.float32)
    output_slide_features = sort_coords(_output_slide_features)
    slide_num_tiles = output_slide_features.shape[0]
    # The check is strict equality, so the message must not claim "<":
    # a slide with duplicated tiles (count above the expected value) would
    # otherwise produce a misleading error.
    assert slide_num_tiles == NUM_TILES_PER_SLIDE, (
        f"Output features for slide {slide_id} contains {slide_num_tiles} "
        f"tiles instead of the expected {NUM_TILES_PER_SLIDE}."
    )
    np.save(export_path, output_slide_features)
def process_imgs(
    imgs: torch.Tensor, tile_ids: list[str], model: torch.nn.Module
) -> np.ndarray:
    """Perform inference on input (already transformed) images.

    Parameters
    ----------
    imgs: torch.Tensor
        Transformed images (e.g. normalized, cropped, etc.).
    tile_ids: list[str]
        List of tile ids; every underscore-separated token after the first
        must parse as an integer coordinate.
    model: torch.nn.Module
        Feature extractor.

    Returns
    -------
    np.ndarray
        (N_tiles, 3 + d) array: integer tile coordinates followed by the
        extracted features.
    """
    with torch.inference_mode():
        batch_features = model(imgs).squeeze()
    # Move to CPU numpy explicitly: feeding a (possibly CUDA) tensor
    # straight into np.concatenate fails for GPU tensors.
    features = batch_features.detach().cpu().numpy()
    # A single-image batch gets fully squeezed down to 1-D; restore the
    # (N_tiles, d) shape so the concatenation below stays valid.
    features = np.atleast_2d(features)
    # (N_tiles, 3) numpy array of integer tile coordinates.
    batch_tiles_coordinates = np.array(
        [tile_id.split("_")[1:] for tile_id in tile_ids]
    ).astype(int)
    return np.concatenate([batch_tiles_coordinates, features], axis=1)