Source code for flamby.datasets.fed_tcga_brca.dataset

import os
from pathlib import Path

import pandas as pd
import torch

from flamby.utils import accept_license


class TcgaBrcaRaw(torch.utils.data.Dataset):
    """Pytorch dataset containing all the clinical features and (event, time)
    information for TCGA-BRCA survival analysis.
    Attributes
    ----------
    X_dtype: torch.dtype
        the dtype of the X features output
    y_dtype: torch.dtype
        the dtype of the (E, T) output
    dic: dict
        dictionary containing the paths to the data and the train_test_split file
    data: pandas.DataFrame
        pandas dataframe containing the data for all the patients
    __getitem__:
        returns a tuple, first element is a torch tensor of dimension
        (39,) for the covariates, second element is a torch tensor of dimension (2,)
        for E, T
    """

    def __init__(self, X_dtype=torch.float32, y_dtype=torch.float32):
        accept_license(
            "https://gdc.cancer.gov/access-data/data-access-processes-and-tools",
            "fed_tcga_brca",
        )
        input_path = Path(os.path.realpath(__file__)).parent.resolve()
        self.dic = {
            "input_preprocessed": os.path.join(input_path, "brca.csv"),
            "train_test_split": os.path.join(
                input_path, "dataset_creation_scripts/train_test_split.csv"
            ),
        }
        self.X_dtype = X_dtype
        self.y_dtype = y_dtype
        self.data = pd.read_csv(self.dic["input_preprocessed"])

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        # Column 0 (patient id) is skipped; columns 1:40 hold the 39 clinical
        # covariates and columns 40:42 hold the (E, T) survival targets.
        x = self.data.iloc[idx, 1:40]
        y = self.data.iloc[idx, 40:42]
        return (torch.tensor(x, dtype=self.X_dtype), torch.tensor(y, dtype=self.y_dtype))
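

# Note (illustrative, not part of the FLamby source): both datasets in this module
# implement the standard torch.utils.data.Dataset interface, so they can be wrapped
# in a DataLoader to get batches X of shape (batch_size, 39) and targets y of shape
# (batch_size, 2), where y[:, 0] is the event indicator E and y[:, 1] the time T.
# A minimal sketch:
#
#     from torch.utils.data import DataLoader
#
#     loader = DataLoader(TcgaBrcaRaw(), batch_size=32, shuffle=True)
#     X, y = next(iter(loader))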


class FedTcgaBrca(TcgaBrcaRaw):
    """
    Pytorch dataset containing all the clinical features and (event, time)
    information for TCGA-BRCA survival analysis.
    One can instantiate this dataset with train or test data coming from either
    of the 6 regions or all regions pooled.
    The train/test split is static and given in the train_test_split file.

    Parameters
    ----------
    center : int, optional
        Between 0 and 5, designates the region in the case of pooled==False.
        Default to 0
    train : bool, optional
        Characterizes if the dataset is used for training or for testing.
        Default to True
    pooled : bool, optional
        Characterizes if the dataset is pooled or not. Default to False
    X_dtype : torch.dtype, optional
        Default to torch.float32
    y_dtype : torch.dtype, optional
        Default to torch.float32
    """

    def __init__(
        self,
        center: int = 0,
        train: bool = True,
        pooled: bool = False,
        X_dtype: torch.dtype = torch.float32,
        y_dtype: torch.dtype = torch.float32,
    ):
        """cf class docstring"""
        super().__init__(X_dtype=X_dtype, y_dtype=y_dtype)
        self.center = center
        self.train_test = "train" if train else "test"
        self.pooled = pooled
        self.key = self.train_test + "_" + str(self.center)
        df = pd.read_csv(self.dic["train_test_split"])
        if self.pooled:
            pids = df.query("fold == '" + self.train_test + "' ").reset_index(
                drop=True
            )
        if not self.pooled:
            assert center in range(6)
            pids = df.query("fold2 == '" + self.key + "' ").reset_index(drop=True)
        pid_list = list(pids["pid"])
        df2 = pd.read_csv(self.dic["input_preprocessed"])
        self.data = df2[df2["pid"].isin(pid_list)]
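

# Illustrative sketch (not part of FLamby): the (E, T) targets returned by these
# datasets can be plugged into a Cox proportional-hazards objective. The helper
# below is an assumed, minimal implementation of the negative Cox partial
# log-likelihood (tied event times are ignored for simplicity); the name
# `_example_cox_loss` is hypothetical.
def _example_cox_loss(scores, y):
    """scores: (N,) predicted risk scores; y: (N, 2) with columns (E, T)."""
    E, T = y[:, 0], y[:, 1]
    order = torch.argsort(T, descending=True)  # sort so the risk set of row i is rows 0..i
    scores, E = scores[order], E[order]
    log_risk_set = torch.logcumsumexp(scores, dim=0)  # log of summed exp(scores) over each risk set
    # sum over observed events only, normalized by the number of events
    return -torch.sum((scores - log_risk_set) * E) / torch.clamp(E.sum(), min=1.0)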


if __name__ == "__main__":

    mydataset = TcgaBrcaRaw()
    print(len(mydataset))
    print("Example of dataset record: ", mydataset[0])

    mydataset = FedTcgaBrca(train=True, pooled=True)
    print(len(mydataset))
    print("Example of dataset record: ", mydataset[0])

    mydataset = FedTcgaBrca(train=False, pooled=True)
    print(len(mydataset))
    print("Example of dataset record: ", mydataset[0])

    for i in range(6):
        mydataset = FedTcgaBrca(center=i, train=True, pooled=False)
        print(len(mydataset))
        print("Example of dataset record: ", mydataset[0])
        mydataset = FedTcgaBrca(center=i, train=False, pooled=False)
        print(len(mydataset))
        print("Example of dataset record: ", mydataset[0])

    mydataset = FedTcgaBrca(center=5, train=False, pooled=False)
    print(len(mydataset))
    for i in range(11):
        print("Example of dataset record: ", mydataset[i])