# Source code for flamby.datasets.fed_kits19.dataset

#    Copyright 2020 Division of Medical Image Computing, German Cancer Research Center
#    (DKFZ), Heidelberg, Germany
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

# Part of this file comes from https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunet
# See flamby/datasets/fed_kits19/dataset_creation_scripts/LICENSE/README.md for more
# information

import os
import sys
from collections import OrderedDict
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from batchgenerators.utilities.file_and_folder_operations import (
    isfile,
    join,
    load_pickle,
)
from nnunet.training.data_augmentation.default_data_augmentation import (
    default_3D_augmentation_params,
    get_patch_size,
)
from torch.utils.data import Dataset

import flamby.datasets.fed_kits19
from flamby.datasets.fed_kits19.dataset_creation_scripts.utils import (
    set_environment_variables,
    transformations,
)
from flamby.utils import check_dataset_from_config

# Put the absolute path of the current working directory at the front of the
# module search path, so it is searched first on subsequent imports.
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "")))


class Kits19Raw(Dataset):
    """Pytorch dataset containing all the images, and segmentations for KiTS19

    Parameters
    ----------
    train : bool, optional
        If True, use the cases whose ``train_test_split`` column in
        ``metadata/thresholded_sites.csv`` is ``"train"``, otherwise the ones
        marked ``"test"``. Defaults to True.
    X_dtype : torch.dtype, optional
        Dtype for inputs `X`. Defaults to `torch.float32`.
    y_dtype : torch.dtype, optional
        Dtype for labels `y`. Defaults to `torch.float32`.
    debug : bool, optional,
        Whether or not to use only the part of the dataset downloaded in
        debug mode. Defaults to False.
    data_path: str
        If data_path is given it will ignore the config file and look for the
        dataset directly in data_path. Defaults to None.

    Notes
    -----
    `X_dtype` and `y_dtype` are stored on the instance but are not applied by
    any method of this class; the returned items are whatever the transforms
    produce.
    """

    def __init__(
        self,
        train=True,
        X_dtype=torch.float32,
        y_dtype=torch.float32,
        debug=False,
        data_path=None,
    ):
        """See description above"""
        if data_path is not None and not os.path.exists(data_path):
            raise ValueError(f"The string {data_path} is not a valid path.")
        # set_environment_variables should be called before importing
        # nnunet.paths, hence the function-scope import right after it.
        set_environment_variables(debug, data_path=data_path)
        from nnunet.paths import preprocessing_output_dir

        if data_path is None:
            check_dataset_from_config("fed_kits19", debug)

        plans_file = (
            preprocessing_output_dir
            + "/Task064_KiTS_labelsFixed/nnUNetPlansv2.1_plans_3D.pkl"
        )
        plans = load_pickle(plans_file)
        stage_plans = plans["plans_per_stage"][0]
        self.patch_size = np.array(stage_plans["patch_size"]).astype(int)

        # Work on a copy: default_3D_augmentation_params is a module-level dict
        # shared by every importer, so writing into it directly would leak this
        # dataset's patch size into unrelated code.
        data_aug_params = dict(default_3D_augmentation_params)
        data_aug_params["patch_size_for_spatialtransform"] = self.patch_size
        basic_generator_patch_size = get_patch_size(
            self.patch_size,
            data_aug_params["rotation_x"],
            data_aug_params["rotation_y"],
            data_aug_params["rotation_z"],
            data_aug_params["scale_range"],
        )

        self.pad_kwargs_data = OrderedDict()
        self.pad_mode = "constant"
        # Extra margin (per spatial axis) between the generator patch size and
        # the final patch size.
        self.need_to_pad = (
            np.array(basic_generator_patch_size) - np.array(self.patch_size)
        ).astype(int)

        self.tr_transform, self.test_transform = transformations(
            data_aug_params["patch_size_for_spatialtransform"], data_aug_params
        )

        self.dataset_directory = (
            preprocessing_output_dir
            + "/Task064_KiTS_labelsFixed/nnUNetData_plans_v2.1_stage0"
        )

        self.X_dtype = X_dtype
        self.y_dtype = y_dtype
        self.debug = debug
        self.train_test = "train" if train else "test"

        df = pd.read_csv(
            Path(os.path.dirname(flamby.datasets.fed_kits19.__file__))
            / Path("metadata")
            / Path("thresholded_sites.csv")
        )
        df2 = df.query("train_test_split == '" + self.train_test + "' ").reset_index(
            drop=True
        )
        self.images = df2.case_ids.tolist()

        # Load image paths and properties files
        self.images_path = self._build_images_path(self.images)
        # Alternates between 0 and 1 on every __getitem__ call; 1 means the next
        # sample is forced to contain foreground.
        self.oversample_next_sample = 0
        self.centers = df2.site_ids

    def _build_images_path(self, case_ids):
        """Map each positional index to the files and properties of one case.

        Parameters
        ----------
        case_ids : iterable
            Case identifiers; ``<case_id>.npz`` and ``<case_id>.pkl`` are looked
            up in ``self.dataset_directory``.

        Returns
        -------
        OrderedDict
            ``index -> OrderedDict`` with keys ``"data_file"``,
            ``"properties_file"`` and ``"properties"`` (the loaded pickle).
        """
        images_path = OrderedDict()
        for index, case_id in enumerate(case_ids):
            entry = OrderedDict()
            entry["data_file"] = join(self.dataset_directory, "%s.npz" % case_id)
            entry["properties_file"] = join(
                self.dataset_directory, "%s.pkl" % case_id
            )
            entry["properties"] = load_pickle(entry["properties_file"])
            images_path[index] = entry
        return images_path

    def __len__(self):
        """Return the number of cases in the dataset."""
        return len(self.images_path)

    def __getitem__(self, idx):
        """Load case `idx`, crop one patch from it and apply the transforms.

        Every other call forces the patch to contain a foreground voxel (see
        `oversample_foreground_class`).

        Returns
        -------
        tuple
            ``item["data"]`` and ``item["target"]`` as produced by the
            train/test transform, each with axis 1 squeezed out.
        """
        data_file = self.images_path[idx]["data_file"]
        # Prefer the unpacked .npy next to the .npz when it exists: it can be
        # memory-mapped instead of fully decompressed.
        npy_file = data_file[:-4] + ".npy"
        if os.path.isfile(npy_file):
            # The numpy keyword is `mmap_mode`; `memmap_mode` raises TypeError.
            case_all_data = np.load(npy_file, mmap_mode="r")
        else:
            case_all_data = np.load(data_file)["data"]

        properties = self.images_path[idx]["properties"]
        # randomly oversample the foreground classes (every second sample)
        force_fg = self.oversample_next_sample == 1
        self.oversample_next_sample = 0 if force_fg else 1
        item = self.oversample_foreground_class(case_all_data, force_fg, properties)

        # apply data augmentations
        if self.train_test == "train":
            item = self.tr_transform(**item)
        elif self.train_test == "test":
            item = self.test_transform(**item)

        return np.squeeze(item["data"], axis=1), np.squeeze(item["target"], axis=1)

    def oversample_foreground_class(self, case_all_data, force_fg, properties):
        """Crop a patch of size `self.patch_size` out of one case.

        Parameters
        ----------
        case_all_data : array-like
            Array indexed as ``[channel, x, y, z]``; all channels but the last
            are used as image data and the last one as segmentation.
        force_fg : bool
            If True, centre the patch on a randomly chosen voxel of a randomly
            chosen foreground class (class id > 0) taken from
            ``properties["class_locations"]``. Falls back to a random crop when
            the case has no foreground voxel.
        properties : dict
            Case properties; must contain ``"class_locations"`` when `force_fg`
            is True.

        Returns
        -------
        dict
            ``{"data": ..., "seg": ...}``, two float32 arrays of shape
            ``(1, 1, *patch_size)``. Regions outside the case are padded using
            `self.pad_mode` for the data and with the constant -1 for the seg.

        Raises
        ------
        RuntimeError
            If `force_fg` is True and `properties` has no ``"class_locations"``.
        """
        # taken from nnunet
        patch_size = self.patch_size
        data = np.zeros((1, 1, *patch_size), dtype=np.float32)
        seg = np.zeros((1, 1, *patch_size), dtype=np.float32)

        shape = case_all_data.shape[1:]
        need_to_pad = self.need_to_pad.copy()
        for d in range(3):
            # if shape + need_to_pad is still < patch size we need to pad more!
            # We always pad on both sides.
            if need_to_pad[d] + shape[d] < patch_size[d]:
                need_to_pad[d] = patch_size[d] - shape[d]

        # The lower corner of the bbox can be chosen anywhere from
        # -need_to_pad // 2 to shape - patch_size + need_to_pad // 2 (inclusive).
        lower = [-need_to_pad[d] // 2 for d in range(3)]
        upper = [
            shape[d] + need_to_pad[d] // 2 + need_to_pad[d] % 2 - patch_size[d]
            for d in range(3)
        ]

        # If not force_fg we sample the bbox randomly between lower and upper.
        # Else we make sure at least one foreground voxel is in the patch.
        selected_voxel = None
        if force_fg:
            # these values should have been precomputed
            if "class_locations" not in properties:
                raise RuntimeError(
                    "Please rerun the preprocessing with the newest version of nnU-Net!"
                )
            class_locations = properties["class_locations"]
            # Preprocessing already listed the voxels of every class, which
            # saves a np.unique here. Keep non-empty classes with id > 0.
            foreground_classes = np.array(
                [c for c in class_locations if len(class_locations[c]) != 0]
            )
            foreground_classes = foreground_classes[foreground_classes > 0]

            if len(foreground_classes) == 0:
                # this only happens if some image does not contain foreground
                # voxels at all
                print("case does not contain any foreground classes")
            else:
                selected_class = np.random.choice(foreground_classes)
                voxels_of_that_class = class_locations[selected_class]
                selected_voxel = voxels_of_that_class[
                    np.random.choice(len(voxels_of_that_class))
                ]

        if selected_voxel is None:
            # Random crop (also the fallback when there is no foreground).
            bbox_lb = [np.random.randint(lower[d], upper[d] + 1) for d in range(3)]
        else:
            # selected voxel is the centre voxel. Subtract half the patch size
            # to get the lower bbox corner, never going below the lower bound.
            bbox_lb = [
                max(lower[d], selected_voxel[d] - patch_size[d] // 2)
                for d in range(3)
            ]
        bbox_ub = [bbox_lb[d] + patch_size[d] for d in range(3)]

        # First crop to the part of the bbox that actually lies within the
        # data: this gives a smaller array which is then faster to pad up to
        # the patch size.
        valid_region = tuple(
            slice(max(0, bbox_lb[d]), min(shape[d], bbox_ub[d])) for d in range(3)
        )
        case_all_data = np.copy(case_all_data[(slice(None), *valid_region)])

        # Amount of the bbox sticking out of the data on each side of each axis
        # (the channel axis is never padded).
        pad_width = (
            (0, 0),
            *(
                (-min(0, bbox_lb[d]), max(bbox_ub[d] - shape[d], 0))
                for d in range(3)
            ),
        )

        data[0] = np.pad(
            case_all_data[:-1], pad_width, self.pad_mode, **self.pad_kwargs_data
        )
        # The segmentation is padded with the constant -1 rather than with the
        # data padding mode.
        seg[0] = np.pad(
            case_all_data[-1:], pad_width, "constant", **{"constant_values": -1}
        )

        return {"data": data, "seg": seg}


class FedKits19(Kits19Raw):
    """
    Pytorch dataset containing for each center the images and associated
    segmentations for KiTS19 federated segmentation.
    One can instantiate this dataset with train or test data coming from
    either of the 6 centers it was created from or all data pooled.
    The train/test split is read from ``metadata/thresholded_sites.csv``.

    Parameters
    ----------
    center : int, optional
        Center id between 0 and 5. Only used when `pooled` is False.
        Default to 0
    train : bool, optional
        Whether to use the train split (True) or the test split (False).
        Default to True
    pooled : bool, optional
        If True, keep the cases of all centers for the chosen split.
        Default to False
    X_dtype : torch.dtype, optional
        Default to torch.float32
    y_dtype : torch.dtype, optional
        Default to torch.float32
    debug : bool, optional
        Whether or not to use only the part of the dataset downloaded in
        debug mode. Default to False.
    data_path : str, optional
        Forwarded to `Kits19Raw`: if given, the dataset is looked for directly
        in data_path instead of using the config file. Default to None.
    """

    def __init__(
        self,
        center: int = 0,
        train: bool = True,
        pooled: bool = False,
        X_dtype: torch.dtype = torch.float32,
        y_dtype: torch.dtype = torch.float32,
        debug: bool = False,
        data_path: str = None,
    ):
        """Cf class docstring"""
        super().__init__(
            X_dtype=X_dtype,
            train=train,
            y_dtype=y_dtype,
            debug=debug,
            data_path=data_path,
        )

        # When pooled, the parent class has already selected every case of the
        # requested split; nothing else to do.
        if pooled:
            return

        assert center in range(6)
        # Value of the `train_test_split_silo` column, e.g. "train_0".
        key = self.train_test + "_" + str(center)

        df = pd.read_csv(
            Path(os.path.dirname(flamby.datasets.fed_kits19.__file__))
            / Path("metadata")
            / Path("thresholded_sites.csv")
        )
        df2 = df.query("train_test_split_silo == '" + key + "' ").reset_index(
            drop=True
        )
        self.images = df2.case_ids.tolist()

        # Rebuild the index -> files/properties mapping for this center only.
        self.images_path = OrderedDict()
        for index, case_id in enumerate(self.images):
            entry = OrderedDict()
            entry["data_file"] = join(self.dataset_directory, "%s.npz" % case_id)
            entry["properties_file"] = join(
                self.dataset_directory, "%s.pkl" % case_id
            )
            entry["properties"] = load_pickle(entry["properties_file"])
            self.images_path[index] = entry

        self.centers = df2.site_ids