Source code for flamby.datasets.fed_camelyon16.dataset

import logging
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

import flamby.datasets.fed_camelyon16
from flamby.utils import check_dataset_from_config


class Camelyon16Raw(Dataset):
    """Pytorch dataset containing all the features, labels and
    metadata for Camelyon16 WS without any discrimination.

    Parameters
    ----------
    X_dtype : torch.dtype, optional
        Dtype for inputs `X`. Defaults to `torch.float32`.
    y_dtype : torch.dtype, optional
        Dtype for labels `y`. Defaults to `torch.float32`.
    debug : bool, optional
        Whether or not to use only the part of the dataset downloaded in
        debug mode. Defaults to False.
    data_path : str, optional
        If given, the config file is ignored and the dataset is looked up
        directly in data_path. Defaults to None.

    Attributes
    ----------
    tiles_dir : pathlib.Path
        The directory where all slides' features are located.
    labels : pd.DataFrame
        The ground-truth DataFrame for the labels.
    metadata : pd.DataFrame
        The ground-truth DataFrame for metadata such as the centers.
    features_paths : list[pathlib.Path]
        The list of paths towards all slides' features.
    features_labels : list[int]
        The list of classification labels for all slides.
    features_centers : list[int]
        The list of centers for all slides.
    features_sets : list[str]
        The list of sets (train/test) for all slides.
    X_dtype : torch.dtype
        The dtype of the X features output.
    y_dtype : torch.dtype
        The dtype of the y label output.
    debug : bool
        Whether or not we use the dataset with only part of the features.
    perms : dict
        The dictionary of cached tile permutations, one entry per slide index.
    """

    def __init__(
        self, X_dtype=torch.float32, y_dtype=torch.float32, debug=False, data_path=None
    ):
        """See description above"""
        if data_path is None:
            config_dict = check_dataset_from_config("fed_camelyon16", debug)
            self.tiles_dir = Path(config_dict["dataset_path"])
        else:
            if not os.path.exists(data_path):
                raise ValueError(f"The string {data_path} is not a valid path.")
            self.tiles_dir = Path(data_path)

        module_dir = Path(os.path.dirname(flamby.datasets.fed_camelyon16.__file__))
        self.labels = pd.read_csv(module_dir / "labels.csv", index_col="filenames")
        self.metadata = pd.read_csv(module_dir / "metadata" / "metadata.csv")
        self.X_dtype = X_dtype
        self.y_dtype = y_dtype
        self.debug = debug
        self.features_paths = []
        self.features_labels = []
        self.features_centers = []
        self.features_sets = []
        self.perms = {}
        # The list is sorted first so that the ordering is reproducible across
        # filesystems, then shuffled with a fixed seed so that the slide order
        # does not correlate with labels or centers.
        # normal_086 and test_049 are filtered out below since they have been
        # removed from the Camelyon16 dataset.
        npys_list = [
            e
            for e in sorted(self.tiles_dir.glob("*.npy"))
            if e.name.lower() not in ("normal_086.tif.npy", "test_049.tif.npy")
        ]
        random.seed(0)
        random.shuffle(npys_list)
        for slide in npys_list:
            slide_name = os.path.basename(slide).split(".")[0].lower()
            slide_id = int(slide_name.split("_")[1])
            slide_rows = [
                e.split(".")[0] == slide_name
                for e in self.metadata["slide_name"].tolist()
            ]
            label_from_metadata = int(self.metadata.loc[slide_rows, "label"].item())
            center_from_metadata = int(
                self.metadata.loc[slide_rows, "hospital_corrected"].item()
            )
            label_from_data = int(self.labels.loc[slide.name.lower()].tumor)

            if "test" not in str(slide).lower():
                if slide_name.startswith("normal"):
                    # Normal slide
                    if slide_id > 100:
                        center_label = 1

                    else:
                        center_label = 0
                    label_from_slide_name = 0  # Normal slide
                elif slide_name.startswith("tumor"):
                    # Tumor slide
                    if slide_id > 70:
                        center_label = 1
                    else:
                        center_label = 0
                    label_from_slide_name = 1  # Tumor slide
                self.features_sets.append("train")
                assert label_from_slide_name == label_from_data, "This shouldn't happen"
                assert center_label == center_from_metadata, "This shouldn't happen"

            else:
                self.features_sets.append("test")

            assert label_from_metadata == label_from_data
            self.features_paths.append(slide)
            self.features_labels.append(label_from_data)
            self.features_centers.append(center_from_metadata)
        if len(self.features_paths) < len(self.labels.index):
            if not self.debug:
                logging.warning(
                    f"You have {len(self.features_paths)} features found in"
                    f" {str(self.tiles_dir)} instead of {len(self.labels.index)} (full"
                    " Camelyon16 dataset), please go back to the installation"
                    " instructions."
                )
            else:
                print(
                    "Warning: you are operating on a reduced dataset in DEBUG mode"
                    f" with in total {len(self.features_paths)}/"
                    f"{len(self.labels.index)} features."
                )

    def __len__(self):
        return len(self.features_paths)

    def __getitem__(self, idx):
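        # start marks the first feature column kept from the saved tile
        # matrix; with 0 the features are used in full.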
        start = 0
        X = np.load(self.features_paths[idx])[:, start:]
        X = torch.from_numpy(X).to(self.X_dtype)
        y = torch.from_numpy(np.asarray(self.features_labels[idx])).to(self.y_dtype)
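        # Generate (and cache) one permutation of the tile indices per slide;
        # the generator is re-seeded with 42 each time, so repeated accesses
        # to the same slide always yield the same permutation.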
        if idx not in self.perms:
            self.perms[idx] = np.random.default_rng(42).permutation(X.shape[0])

        return X, y, self.perms[idx]
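

# A minimal usage sketch (illustrative, not part of the original module): it
# assumes the dataset has been downloaded following FLamby's installation
# instructions so that the .npy feature files are present. Each item is a
# (features, label, permutation) triple.
def _example_raw_usage():
    dataset = Camelyon16Raw()
    X, y, perm = dataset[0]
    # X: (n_tiles, feature_dim) tile embeddings for one slide
    # y: scalar tumor label, perm: fixed permutation of the tile indices
    print(X.shape, y.item(), perm.shape)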


class FedCamelyon16(Camelyon16Raw):
    """
    Pytorch dataset containing for each center the features and associated
    labels for Camelyon16 federated classification.
    One can instantiate this dataset with train or test data coming from
    either of the 2 centers it was created from or all data pooled.
    The train/test split corresponds to the one from the Challenge.

    Parameters
    ----------
    center : int, optional
        Default to 0.
    train : bool, optional
        Default to True.
    pooled : bool, optional
        Whether to take all data from the 2 centers into one dataset, by
        default False.
    X_dtype : torch.dtype, optional
        Dtype for inputs `X`. Defaults to `torch.float32`.
    y_dtype : torch.dtype, optional
        Dtype for labels `y`. Defaults to `torch.float32`.
    debug : bool, optional
        Whether or not to use only the part of the dataset downloaded in
        debug mode. Defaults to False.
    data_path : str, optional
        If given, the config file is ignored and the dataset is looked up
        directly in data_path. Defaults to None.
    """

    def __init__(
        self,
        center: int = 0,
        train: bool = True,
        pooled: bool = False,
        X_dtype: torch.dtype = torch.float32,
        y_dtype: torch.dtype = torch.float32,
        debug: bool = False,
        data_path: str = None,
    ):
        """Cf class docstring"""
        super().__init__(
            X_dtype=X_dtype, y_dtype=y_dtype, debug=debug, data_path=data_path
        )
        assert center in [0, 1]
        self.centers = [center]
        if pooled:
            self.centers = [0, 1]
        if train:
            self.sets = ["train"]
        else:
            self.sets = ["test"]

        to_select = [
            (self.features_sets[idx] in self.sets)
            and (self.features_centers[idx] in self.centers)
            for idx, _ in enumerate(self.features_centers)
        ]
        self.features_paths = [
            fp for idx, fp in enumerate(self.features_paths) if to_select[idx]
        ]
        self.features_sets = [
            fp for idx, fp in enumerate(self.features_sets) if to_select[idx]
        ]
        self.features_labels = [
            fp for idx, fp in enumerate(self.features_labels) if to_select[idx]
        ]
        self.features_centers = [
            fp for idx, fp in enumerate(self.features_centers) if to_select[idx]
        ]
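

# Illustrative sketch (not part of the original module): how one would build
# the two training clients and the pooled test set, assuming the full dataset
# is available locally.
def _example_federated_split():
    train_center0 = FedCamelyon16(center=0, train=True)
    train_center1 = FedCamelyon16(center=1, train=True)
    test_pooled = FedCamelyon16(train=False, pooled=True)
    print(len(train_center0), len(train_center1), len(test_pooled))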


def collate_fn(dataset_elements_list, max_tiles=10000):
    """Helper function to correctly batch samples from a Camelyon16 dataset,
    accommodating the uneven number of tiles per slide.

    Parameters
    ----------
    dataset_elements_list : List[Tuple[torch.Tensor, torch.Tensor, np.ndarray]]
        A list of (X, y, perm) triples where each X has dimensions [n, m]
        with an uneven distribution of ns.
    max_tiles : int, optional
        The maximum number of tiles kept per tensor, by default 10000.

    Returns
    -------
    Tuple(torch.Tensor, torch.Tensor)
        X, y two torch tensors of size
        (len(dataset_elements_list), max_tiles, m) and
        (len(dataset_elements_list), 1).
    """
    n = len(dataset_elements_list)
    X0, y0, _ = dataset_elements_list[0]
    feature_dim = X0.size(1)
    X_dtype = X0.dtype
    y_dtype = y0.dtype
    # Zero-pad every slide to max_tiles tiles so the batch is rectangular
    X = torch.zeros((n, max_tiles, feature_dim), dtype=X_dtype)
    y = torch.empty((n, 1), dtype=y_dtype)
    for i in range(n):
        X_current, y_current, perm = dataset_elements_list[i]
        # Keep at most max_tiles tiles, taken in the slide's fixed
        # permutation order
        ntiles_min = min(max_tiles, X_current.shape[0])
        X[i, :ntiles_min, :] = X_current[perm[:ntiles_min], :]
        y[i] = y_current
    return X, y
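

# Illustrative sketch (not part of the original module): batching uneven
# slides through a standard DataLoader with collate_fn; the batch size is
# arbitrary here.
def _example_dataloader():
    from torch.utils.data import DataLoader

    dataset = FedCamelyon16(center=0, train=True)
    loader = DataLoader(
        dataset, batch_size=4, shuffle=True, collate_fn=collate_fn
    )
    X, y = next(iter(loader))
    # X: (4, 10000, feature_dim) zero-padded tile features, y: (4, 1) labels
    print(X.shape, y.shape)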