import os
from pathlib import Path
from typing import Dict, Tuple
from zipfile import ZipFile
import pandas as pd
import torch
from monai.transforms import (
AddChannel,
AsDiscrete,
Compose,
NormalizeIntensity,
Resize,
ToTensor,
)
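# NOTE: AddChannel was deprecated in MONAI 0.8 and removed in MONAI 1.3; on
# recent MONAI versions, EnsureChannelFirst(channel_dim="no_channel") is the
# closest replacement for the usage below.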
from torch import Tensor
from torch.utils.data import Dataset
import flamby
from flamby.datasets.fed_ixi.utils import (
_extract_center_name_from_filename,
_get_center_name_from_center_id,
_get_id_from_filename,
_load_nifti_image_and_label_by_id,
)
from flamby.utils import check_dataset_from_config
class IXITinyRaw(Dataset):
"""
Generic interface for IXI Tiny Dataset
Parameters
----------
transform : optional
PyTorch Transform to process the data or augment it. Default to None
debug : bool, optional
Default to False
data_path: str
If data_path is given it will ignore the config file and look for the
dataset directly in data_path. Defaults to None.
"""
CENTER_LABELS = {"Guys": 0, "HH": 1, "IOP": 2}
def __init__(self, transform=None, debug=False, data_path=None):
if data_path is None:
            config_dict = check_dataset_from_config("fed_ixi", debug)
            self.root_folder = Path(config_dict["dataset_path"])
else:
            if not os.path.exists(data_path):
raise ValueError(f"The string {data_path} is not a valid path.")
self.root_folder = Path(data_path)
        self.metadata = pd.read_csv(
            Path(os.path.dirname(flamby.datasets.fed_ixi.__file__))
            / "metadata"
            / "metadata_tiny.csv",
            index_col="Patient ID",
        )
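        # metadata_tiny.csv is indexed by "Patient ID" and provides (at least)
        # a "Split" column with "train"/"test" values, used below to build
        # self.images_sets.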
self.common_shape = (48, 60, 48)
self.transform = transform
self.modality = "T1"
        # The IXI Tiny archive must be downloaded and extracted before this
        # part can run
# Deferring the import to avoid circular imports
from flamby.datasets.fed_ixi.common import DATASET_URL, FOLDER
self.image_url = DATASET_URL
self.parent_folder = FOLDER
self.parent_dir_name = os.path.join(self.parent_folder, "IXI_sample")
self.subjects_dir = os.path.join(self.root_folder, self.parent_dir_name)
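        # Expected on-disk layout (cf. the loop below):
        #   <root_folder>/<FOLDER>/IXI_sample/<subject>/T1/*.nii.gz
        #   <root_folder>/<FOLDER>/IXI_sample/<subject>/label/*.nii.gz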
# contains paths of archives which contain a nifti image for each subject
self.images_paths = []
# contains paths of archives which contain a label (binary brain mask) for
# each subject
self.labels_paths = []
self.images_centers = [] # contains center of each subject: HH, Guys or IOP
self.images_sets = [] # train and test
self.subjects = [
subject
for subject in os.listdir(self.subjects_dir)
if os.path.isdir(os.path.join(self.subjects_dir, subject))
]
self.images_centers = [
_extract_center_name_from_filename(subject) for subject in self.subjects
]
self.demographics = Path(os.path.join(self.subjects_dir, "IXI.xls"))
for subject in self.subjects:
patient_id = _get_id_from_filename(subject)
self.images_sets.append(self.metadata.loc[patient_id, "Split"])
subject_dir = os.path.join(self.subjects_dir, subject)
image_path = Path(os.path.join(subject_dir, "T1"))
label_path = Path(os.path.join(subject_dir, "label"))
self.images_paths.extend(image_path.glob("*.nii.gz"))
self.labels_paths.extend(label_path.glob("*.nii.gz"))
self.filenames = [filename.name for filename in self.images_paths]
self.subject_ids = tuple(map(_get_id_from_filename, self.filenames))
@property
def zip_file(self) -> ZipFile:
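        # Note: this opens a fresh ZipFile handle on every access and does not
        # close it explicitly.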
zf = self.root_folder.joinpath(self.parent_folder + ".zip")
return ZipFile(zf)
    def _validate_center(self) -> None:
        """
        Check that the requested center is one of the permitted values.
        Only meaningful on subclasses that set ``self.centers``
        (e.g. ``FedIXITiny``).

        Allowed values are:
        - 0
        - 1
        - 2
        - Guys
        - HH
        - IOP

        Raises
        ------
        AssertionError
            If the `center` argument is not among the possible centers.
        """
        centers = list(self.CENTER_LABELS.keys()) + list(self.CENTER_LABELS.values())
        assert self.centers[0] in centers, (
            f"Center {self.centers[0]} "
            "is not compatible with this dataset. "
            f"Existing centers can be named as follows: {centers}"
        )
def __getitem__(self, item) -> Tuple[Tensor, Dict]:
patient_id = self.subject_ids[item]
header_img, img, label, center_name = _load_nifti_image_and_label_by_id(
zip_file=self.zip_file, patient_id=patient_id, modality=self.modality
)
default_transform = Compose(
[ToTensor(), AddChannel(), Resize(self.common_shape)]
)
intensity_transform = Compose([NormalizeIntensity()])
one_hot_transform = Compose([AsDiscrete(to_onehot=2)])
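        # Image pipeline: to tensor -> add a leading channel dim -> resize to
        # the common (48, 60, 48) shape, then z-score intensity normalization.
        # The binary brain mask goes through the same default pipeline and is
        # then one-hot encoded over 2 classes (background / brain).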
img = default_transform(img)
img = intensity_transform(img)
label = default_transform(label)
label = one_hot_transform(label)
if self.transform:
img = self.transform(img)
return img.to(torch.float32), label
def __len__(self) -> int:
return len(self.images_paths)
class FedIXITiny(IXITinyRaw):
"""
Federated class for T1 images in IXI Tiny Dataset
Parameters
----------
transform:
PyTorch Transform to process the data or augment it.
center: int, optional
Id of the center (hospital) from which to gather data. Defaults to 0.
train : bool, optional
Whether to take the train or test split. Defaults to True (train).
pooled : bool, optional
Whether to take all data from the 3 centers into one dataset.
If True, supersedes center argument. Defaults to False.
debug : bool, optional
Default to False.
data_path: str
If data_path is given it will ignore the config file and look for the
dataset directly in data_path. Defaults to None.
"""
def __init__(
self,
transform=None,
center=0,
train=True,
pooled=False,
debug=False,
data_path=None,
):
"""
Cf class docstring
"""
super(FedIXITiny, self).__init__(
transform=transform, debug=debug, data_path=data_path
)
self.modality = "T1"
self.centers = [center]
self._validate_center()
if isinstance(center, int):
self.centers = [_get_center_name_from_center_id(self.CENTER_LABELS, center)]
if pooled:
self.centers = ["Guys", "HH", "IOP"]
if train:
self.sets = ["train"]
else:
self.sets = ["test"]
to_select = [
(self.images_centers[idx] in self.centers)
and (self.images_sets[idx] in self.sets)
for idx, _ in enumerate(self.images_centers)
]
self.images_paths = [self.images_paths[i] for i, s in enumerate(to_select) if s]
self.labels_paths = [self.labels_paths[i] for i, s in enumerate(to_select) if s]
self.images_centers = [
self.images_centers[i] for i, s in enumerate(to_select) if s
]
self.images_sets = [self.images_sets[i] for i, s in enumerate(to_select) if s]
self.filenames = [filename.name for filename in self.images_paths]
self.subject_ids = tuple(map(_get_id_from_filename, self.filenames))
if __name__ == "__main__":
a = IXITinyRaw()
print("IXI Tiny dataset size:", len(a))
a = FedIXITiny()
print(
"Data gathered in this federated dataset is from:",
*a.centers,
"and",
*a.sets,
"set",
)
print("Federated dataset size:", len(a))
print("First entry:", a[0])
__all__ = ["IXITinyRaw", "FedIXITiny"]