import os
import random
from pathlib import Path
import albumentations
import numpy as np
import pandas as pd
import torch
from PIL import Image
from flamby.utils import check_dataset_from_config
class Isic2019Raw(torch.utils.data.Dataset):
"""Pytorch dataset containing all the features, labels and datacenter
information for Isic2019.
Attributes
----------
image_paths: list[str]
the list with the path towards all features
targets: list[int]
the list with all classification labels for all features
centers: list[int]
the list for all datacenters for all features
X_dtype: torch.dtype
the dtype of the X features output
y_dtype: torch.dtype
the dtype of the y label output
augmentations:
image transform operations from the albumentations library,
used for data augmentation
data_path: str
If data_path is given it will ignore the config file and look for the
dataset directly in data_path. Defaults to None.
Parameters
----------
X_dtype :
y_dtype :
augmentations :
"""
def __init__(
self,
X_dtype=torch.float32,
y_dtype=torch.int64,
augmentations=None,
data_path=None,
):
"""
Cf class docstring
"""
        if data_path is None:
            config_dict = check_dataset_from_config(
                dataset_name="fed_isic2019", debug=False
            )
            input_path = config_dict["dataset_path"]
        else:
            if not os.path.exists(data_path):
                raise ValueError(f"The string {data_path} is not a valid path.")
            input_path = data_path
        current_dir = str(Path(os.path.realpath(__file__)).parent.resolve())
        self.dic = {
            "input_preprocessed": os.path.join(
                input_path, "ISIC_2019_Training_Input_preprocessed"
            ),
            "train_test_split": os.path.join(
                current_dir, "dataset_creation_scripts/train_test_split"
            ),
        }
self.X_dtype = X_dtype
self.y_dtype = y_dtype
df2 = pd.read_csv(self.dic["train_test_split"])
images = df2.image.tolist()
self.image_paths = [
os.path.join(self.dic["input_preprocessed"], image_name + ".jpg")
for image_name in images
]
self.targets = df2.target
self.augmentations = augmentations
self.centers = df2.center
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
image_path = self.image_paths[idx]
image = np.array(Image.open(image_path))
target = self.targets[idx]
# Image augmentations
if self.augmentations is not None:
augmented = self.augmentations(image=image)
image = augmented["image"]
        # Convert from channels-last (H, W, C) to channels-first (C, H, W) as
        # expected by torch models
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)
return (
torch.tensor(image, dtype=self.X_dtype),
torch.tensor(target, dtype=self.y_dtype),
)
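
# Illustrative usage sketch, not part of the original FLamby module: it shows
# how Isic2019Raw can be paired with a user-defined albumentations pipeline
# (here a simple center crop followed by normalization), assuming the
# preprocessed ISIC 2019 images are available either through the FLamby config
# file or under a user-supplied data_path. The helper name and the chosen
# transforms are examples only.
def _isic2019_raw_example(data_path=None):
    augmentations = albumentations.Compose(
        [
            albumentations.CenterCrop(200, 200),
            albumentations.Normalize(always_apply=True),
        ]
    )
    dataset = Isic2019Raw(augmentations=augmentations, data_path=data_path)
    x, y = dataset[0]
    print(f"{len(dataset)} samples, first image {tuple(x.shape)}, label {int(y)}")
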
class FedIsic2019(Isic2019Raw):
"""
Pytorch dataset containing for each center the features and associated labels
for the Isic2019 federated classification.
One can instantiate this dataset with train or test data coming from either of
the 6 centers it was created from or all data pooled.
The train/test split is fixed and given in the train_test_split file.
Parameters
----------
center : int, optional
Default to 0
train : bool, optional
Default to True
pooled : bool, optional
Default to False
debug : bool, optional
Default to False
X_dtype : torch.dtype, optional
Default to torch.float32
y_dtype : torch.dtype, optional
Default to torch.int64
data_path: str
If data_path is given it will ignore the config file and look for the
dataset directly in data_path. Defaults to None.
"""
def __init__(
self,
center: int = 0,
train: bool = True,
pooled: bool = False,
debug: bool = False,
X_dtype: torch.dtype = torch.float32,
y_dtype: torch.dtype = torch.int64,
data_path: str = None,
):
"""Cf class docstring"""
        sz = 200  # Side length of the square crops fed to the model
        if train:
            # Training-time augmentations: light geometric and photometric
            # perturbations, a random crop and per-channel normalization.
            augmentations = albumentations.Compose(
[
albumentations.RandomScale(0.07),
albumentations.Rotate(50),
albumentations.RandomBrightnessContrast(0.15, 0.1),
albumentations.Flip(p=0.5),
albumentations.Affine(shear=0.1),
albumentations.RandomCrop(sz, sz),
albumentations.CoarseDropout(random.randint(1, 8), 16, 16),
albumentations.Normalize(always_apply=True),
]
)
        else:
            # Test-time transform: deterministic center crop and normalization.
            augmentations = albumentations.Compose(
[
albumentations.CenterCrop(sz, sz),
albumentations.Normalize(always_apply=True),
]
)
super().__init__(
X_dtype=X_dtype,
y_dtype=y_dtype,
augmentations=augmentations,
data_path=data_path,
)
self.center = center
self.train_test = "train" if train else "test"
self.pooled = pooled
self.key = self.train_test + "_" + str(self.center)
df = pd.read_csv(self.dic["train_test_split"])
        if self.pooled:
            df2 = df.query(f"fold == '{self.train_test}'").reset_index(drop=True)
        else:
            assert center in range(6)
            df2 = df.query(f"fold2 == '{self.key}'").reset_index(drop=True)
images = df2.image.tolist()
self.image_paths = [
os.path.join(self.dic["input_preprocessed"], image_name + ".jpg")
for image_name in images
]
self.targets = df2.target
self.centers = df2.center
if __name__ == "__main__":
mydataset = Isic2019Raw()
print("Example of dataset record: ", mydataset[0])
print(f"The dataset has {len(mydataset)} elements")
for i in range(10):
print(f"Size of image {i} ", mydataset[i][0].shape)
print(f"Target {i} ", mydataset[i][1])
mydataset = FedIsic2019(train=True, pooled=True)
print(len(mydataset))
print("Size of image 0 ", mydataset[0][0].shape)
mydataset = FedIsic2019(train=False, pooled=True)
print(len(mydataset))
print("Size of image 0 ", mydataset[0][0].shape)
for i in range(6):
mydataset = FedIsic2019(center=i, train=True, pooled=False)
print(len(mydataset))
print("Size of image 0 ", mydataset[0][0].shape)
mydataset = FedIsic2019(center=i, train=False, pooled=False)
print(len(mydataset))
print("Size of image 0 ", mydataset[0][0].shape)
mydataset = FedIsic2019(center=5, train=False, pooled=False)
print(len(mydataset))
for i in range(11):
print(f"Size of image {i} ", mydataset[i][0].shape)