Skip to content

Estimating LFCs and dispersions

compute_MAP_dispersions

Module containing the mixin class to compute MAP dispersions.

compute_MAP_dispersions

Main module to compute dispersions by minimizing the MLE using a grid search.

ComputeMAPDispersions

Bases: LocFilterMAPDispersions, ComputeDispersionsGridSearch

Mixin class to implement the computation of MAP dispersions.

Methods:

Name Description
fit_MAP_dispersions

A method to fit the MAP dispersions and filter them. The filtering is done by removing the dispersions that are too far from the trend curve.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_MAP_dispersions/compute_MAP_dispersions.py
class ComputeMAPDispersions(
    LocFilterMAPDispersions,
    ComputeDispersionsGridSearch,
):
    """Mixin orchestrating the estimation of MAP dispersions.

    Methods
    -------
    fit_MAP_dispersions
        Fit the MAP dispersions through ``fit_dispersions`` and then run the
        ``filter_outlier_genes`` local step on the train data nodes.

    """

    def fit_MAP_dispersions(
        self,
        train_data_nodes,
        aggregation_node,
        local_states,
        shared_state,
        round_idx,
        clean_models,
        refit_mode: bool = False,
    ):
        """Run the MAP dispersion fit followed by the outlier-gene filter.

        The dispersions are first fitted by calling ``fit_dispersions`` with
        ``fit_mode="MAP"``. The shared state it returns is then passed to
        ``filter_outlier_genes`` through a local step.

        Parameters
        ----------
        train_data_nodes : list
            List of TrainDataNode.

        aggregation_node : AggregationNode
            The aggregation node.

        local_states : dict
            Local states, needed to propagate intermediate results.

        shared_state : dict
            Output of the trend fitting. It is documented as carrying the
            fitted dispersions of the trend curve, a "prior_disp_var" field
            (prior variance of the dispersions) and a "_squared_logres" field
            (squared residuals of the trend fitting).

        round_idx : int
            The current round.

        clean_models : bool
            Whether to clean the models after the computation.

        refit_mode : bool
            Whether to run in refit mode, i.e. after Cook's outliers were
            replaced. If True, the steps work on `refit_adata`s instead of
            `local_adata`s. (default: False).

        Returns
        -------
        local_states : dict
            Local states, needed to propagate intermediate results.

        round_idx : int
            The updated round index.
        """
        # First stage: the dispersion fit itself, run in MAP mode.
        fit_outputs = self.fit_dispersions(
            train_data_nodes,
            aggregation_node,
            local_states,
            shared_state=shared_state,
            round_idx=round_idx,
            clean_models=clean_models,
            fit_mode="MAP",
            refit_mode=refit_mode,
        )
        local_states, shared_state, round_idx = fit_outputs

        # Second stage: each node filters its MAP dispersions. The shared
        # state produced by this local step is not needed afterwards.
        filter_outputs = local_step(
            local_method=self.filter_outlier_genes,
            method_params={"refit_mode": refit_mode},
            train_data_nodes=train_data_nodes,
            input_local_states=local_states,
            output_local_states=local_states,
            input_shared_state=shared_state,
            aggregation_id=aggregation_node.organization_id,
            description="Filter MAP dispersions.",
            round_idx=round_idx,
            clean_models=clean_models,
        )
        local_states, _discarded_shared_state, round_idx = filter_outputs

        return local_states, round_idx
fit_MAP_dispersions(train_data_nodes, aggregation_node, local_states, shared_state, round_idx, clean_models, refit_mode=False)

Fit MAP dispersions, and apply filtering.

Parameters:

Name Type Description Default
train_data_nodes

List of TrainDataNode.

required
aggregation_node

The aggregation node.

required
local_states

Local states. Required to propagate intermediate results.

required
shared_state

Contains the output of the trend fitting, that is, a dictionary with a "fitted_dispersions" field containing the fitted dispersions from the trend curve, a "prior_disp_var" field containing the prior variance of the dispersions, and a "_squared_logres" field containing the squared residuals of the trend fitting.

required
round_idx

The current round.

required
clean_models

Whether to clean the models after the computation.

required
refit_mode bool

Whether to run the pipeline in refit mode, after cooks outliers were replaced. If True, the pipeline will be run on refit_adatas instead of local_adatas. (default: False).

False

Returns:

Name Type Description
local_states dict

Local states. Required to propagate intermediate results.

round_idx int

The updated round index.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_MAP_dispersions/compute_MAP_dispersions.py
def fit_MAP_dispersions(
    self,
    train_data_nodes,
    aggregation_node,
    local_states,
    shared_state,
    round_idx,
    clean_models,
    refit_mode: bool = False,
):
    """Run the MAP dispersion fit followed by the outlier-gene filter.

    The dispersions are first fitted by calling ``fit_dispersions`` with
    ``fit_mode="MAP"``. The shared state it returns is then passed to
    ``filter_outlier_genes`` through a local step.

    Parameters
    ----------
    train_data_nodes : list
        List of TrainDataNode.

    aggregation_node : AggregationNode
        The aggregation node.

    local_states : dict
        Local states, needed to propagate intermediate results.

    shared_state : dict
        Output of the trend fitting. It is documented as carrying the
        fitted dispersions of the trend curve, a "prior_disp_var" field
        (prior variance of the dispersions) and a "_squared_logres" field
        (squared residuals of the trend fitting).

    round_idx : int
        The current round.

    clean_models : bool
        Whether to clean the models after the computation.

    refit_mode : bool
        Whether to run in refit mode, i.e. after Cook's outliers were
        replaced. If True, the steps work on `refit_adata`s instead of
        `local_adata`s. (default: False).

    Returns
    -------
    local_states : dict
        Local states, needed to propagate intermediate results.

    round_idx : int
        The updated round index.
    """
    # First stage: the dispersion fit itself, run in MAP mode.
    fit_outputs = self.fit_dispersions(
        train_data_nodes,
        aggregation_node,
        local_states,
        shared_state=shared_state,
        round_idx=round_idx,
        clean_models=clean_models,
        fit_mode="MAP",
        refit_mode=refit_mode,
    )
    local_states, shared_state, round_idx = fit_outputs

    # Second stage: each node filters its MAP dispersions. The shared
    # state produced by this local step is not needed afterwards.
    filter_outputs = local_step(
        local_method=self.filter_outlier_genes,
        method_params={"refit_mode": refit_mode},
        train_data_nodes=train_data_nodes,
        input_local_states=local_states,
        output_local_states=local_states,
        input_shared_state=shared_state,
        aggregation_id=aggregation_node.organization_id,
        description="Filter MAP dispersions.",
        round_idx=round_idx,
        clean_models=clean_models,
    )
    local_states, _discarded_shared_state, round_idx = filter_outputs

    return local_states, round_idx

substeps

LocFilterMAPDispersions

Mixin to filter MAP dispersions and obtain the final dispersion estimates.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_MAP_dispersions/substeps.py
class LocFilterMAPDispersions:
    """Mixin providing the local step that sets the final dispersions.

    The final estimates are the MAP dispersions, except for genes flagged as
    outliers, which keep their gene-wise dispersions.
    """

    # AnnData used when ``refit_mode`` is False.
    local_adata: AnnData
    # AnnData used when ``refit_mode`` is True.
    refit_adata: AnnData

    @remote_data
    @log_remote_data
    @reconstruct_adatas
    def filter_outlier_genes(
        self,
        data_from_opener,
        shared_state,
        refit_mode: bool = False,
    ) -> None:
        """Flag outlier genes and keep their gene-wise dispersions.

        Genes whose gene-wise dispersion lies too far above the trend curve
        are not shrunk: their final dispersion is the gene-wise one rather
        than the MAP one.

        Parameters
        ----------
        data_from_opener : ad.AnnData
            Not used.

        shared_state : dict
            Contains:
            - "MAP_dispersions": MAP dispersions,

        refit_mode : bool
            Whether to work on `refit_adata` instead of `local_adata`.
            (default: False).
        """
        adata = self.refit_adata if refit_mode else self.local_adata

        # Store the MAP estimates and use a copy of them as the starting
        # point for the final dispersions.
        adata.varm["MAP_dispersions"] = shared_state["MAP_dispersions"].copy()
        adata.varm["dispersions"] = adata.varm["MAP_dispersions"].copy()

        # A gene is an outlier when its log gene-wise dispersion exceeds the
        # log fitted dispersion by more than 2 * sqrt(_squared_logres).
        log_genewise = np.log(adata.varm["genewise_dispersions"])
        log_upper_bound = np.log(adata.varm["fitted_dispersions"]) + 2 * np.sqrt(
            adata.uns["_squared_logres"]
        )
        adata.varm["_outlier_genes"] = log_genewise > log_upper_bound

        # Outlier genes fall back to their gene-wise dispersions.
        outlier_mask = adata.varm["_outlier_genes"]
        genewise_of_outliers = adata.varm["genewise_dispersions"][outlier_mask]
        adata.varm["dispersions"][outlier_mask] = genewise_of_outliers
filter_outlier_genes(data_from_opener, shared_state, refit_mode=False)

Filter out outlier genes.

Avoids shrinking the dispersions of genes that are too far from the trend curve.

Parameters:

Name Type Description Default
data_from_opener AnnData

Not used.

required
shared_state dict

Contains: - "MAP_dispersions": MAP dispersions,

required
refit_mode bool

Whether to run the pipeline on refit_adatas instead of local_adatas. (default: False).

False
Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_MAP_dispersions/substeps.py
@remote_data
@log_remote_data
@reconstruct_adatas
def filter_outlier_genes(
    self,
    data_from_opener,
    shared_state,
    refit_mode: bool = False,
) -> None:
    """Flag outlier genes and keep their gene-wise dispersions.

    Genes whose gene-wise dispersion lies too far above the trend curve
    are not shrunk: their final dispersion is the gene-wise one rather
    than the MAP one.

    Parameters
    ----------
    data_from_opener : ad.AnnData
        Not used.

    shared_state : dict
        Contains:
        - "MAP_dispersions": MAP dispersions,

    refit_mode : bool
        Whether to work on `refit_adata` instead of `local_adata`.
        (default: False).
    """
    adata = self.refit_adata if refit_mode else self.local_adata

    # Store the MAP estimates and use a copy of them as the starting
    # point for the final dispersions.
    adata.varm["MAP_dispersions"] = shared_state["MAP_dispersions"].copy()
    adata.varm["dispersions"] = adata.varm["MAP_dispersions"].copy()

    # A gene is an outlier when its log gene-wise dispersion exceeds the
    # log fitted dispersion by more than 2 * sqrt(_squared_logres).
    log_genewise = np.log(adata.varm["genewise_dispersions"])
    log_upper_bound = np.log(adata.varm["fitted_dispersions"]) + 2 * np.sqrt(
        adata.uns["_squared_logres"]
    )
    adata.varm["_outlier_genes"] = log_genewise > log_upper_bound

    # Outlier genes fall back to their gene-wise dispersions.
    outlier_mask = adata.varm["_outlier_genes"]
    genewise_of_outliers = adata.varm["genewise_dispersions"][outlier_mask]
    adata.varm["dispersions"][outlier_mask] = genewise_of_outliers

compute_dispersion_prior

compute_dispersion_prior

Module containing the steps for fitting the dispersion trend.

ComputeDispersionPrior

Bases: AggFitDispersionTrendAndPrior, LocGetMeanDispersionAndMean, LocUpdateFittedDispersions

Mixin class to implement the fit of the dispersion trend.

Methods:

Name Description
compute_dispersion_prior

The method to fit the dispersion trend.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_dispersion_prior/compute_dispersion_prior.py
class ComputeDispersionPrior(
    AggFitDispersionTrendAndPrior,
    LocGetMeanDispersionAndMean,
    LocUpdateFittedDispersions,
):
    """Mixin orchestrating the fit of the dispersion trend.

    Methods
    -------
    compute_dispersion_prior
        Collect local means and dispersions, then fit the dispersion trend
        on the aggregation node.

    """

    def compute_dispersion_prior(
        self,
        train_data_nodes,
        aggregation_node,
        local_states,
        genewise_dispersions_shared_state,
        round_idx,
        clean_models,
    ):
        """Fit the dispersion trend.

        Runs a local step (``get_local_mean_and_dispersion``) followed by an
        aggregation step (``agg_fit_dispersion_trend_and_prior_dispersion``).

        Parameters
        ----------
        train_data_nodes : list
            List of TrainDataNode.

        aggregation_node : AggregationNode
            The aggregation node.

        local_states : dict
            Local states, needed to propagate intermediate results.

        genewise_dispersions_shared_state : dict
            Shared state with a "genewise_dispersions" key.

        round_idx : int
            Index of the current round.

        clean_models : bool
            Whether to clean the models after the computation.

        Returns
        -------
        local_states : dict
            Local states, needed to propagate intermediate results.

        dispersion_trend_shared_state : dict
            Shared state produced by the aggregation step, with among others:
            - "fitted_dispersions": the fitted dispersions,
            - "prior_disp_var": the prior dispersion variance.

        round_idx : int
            The updated round index.

        """
        # Local step: each node reports its means and dispersions.
        # TODO : merge this step with the last steps from genewise dispersion
        local_outputs = local_step(
            local_method=self.get_local_mean_and_dispersion,
            train_data_nodes=train_data_nodes,
            input_local_states=local_states,
            output_local_states=local_states,
            input_shared_state=genewise_dispersions_shared_state,
            aggregation_id=aggregation_node.organization_id,
            description="Get local means and dispersions",
            round_idx=round_idx,
            clean_models=clean_models,
        )
        local_states, shared_states, round_idx = local_outputs

        # Aggregation step: fit the dispersion trend from the local reports.
        aggregation_outputs = aggregation_step(
            aggregation_method=self.agg_fit_dispersion_trend_and_prior_dispersion,
            train_data_nodes=train_data_nodes,
            aggregation_node=aggregation_node,
            input_shared_states=shared_states,
            description="Fitting dispersion trend",
            round_idx=round_idx,
            clean_models=clean_models,
        )
        dispersion_trend_shared_state, round_idx = aggregation_outputs

        return local_states, dispersion_trend_shared_state, round_idx
compute_dispersion_prior(train_data_nodes, aggregation_node, local_states, genewise_dispersions_shared_state, round_idx, clean_models)

Fit the dispersion trend.

Parameters:

Name Type Description Default
train_data_nodes

list of TrainDataNode.

required
aggregation_node

The aggregation node.

required
local_states

Local states. Required to propagate intermediate results.

required
genewise_dispersions_shared_state

Shared state with a "genewise_dispersions" key.

required
round_idx

Index of the current round.

required
clean_models

Whether to clean the models after the computation.

required

Returns:

Name Type Description
local_states dict

Local states. Required to propagate intermediate results.

dispersion_trend_shared_state dict

Shared states with: - "fitted_dispersions": the fitted dispersions, - "prior_disp_var": the prior dispersion variance.

round_idx int

The updated round index.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_dispersion_prior/compute_dispersion_prior.py
def compute_dispersion_prior(
    self,
    train_data_nodes,
    aggregation_node,
    local_states,
    genewise_dispersions_shared_state,
    round_idx,
    clean_models,
):
    """Fit the dispersion trend.

    Runs a local step (``get_local_mean_and_dispersion``) followed by an
    aggregation step (``agg_fit_dispersion_trend_and_prior_dispersion``).

    Parameters
    ----------
    train_data_nodes : list
        List of TrainDataNode.

    aggregation_node : AggregationNode
        The aggregation node.

    local_states : dict
        Local states, needed to propagate intermediate results.

    genewise_dispersions_shared_state : dict
        Shared state with a "genewise_dispersions" key.

    round_idx : int
        Index of the current round.

    clean_models : bool
        Whether to clean the models after the computation.

    Returns
    -------
    local_states : dict
        Local states, needed to propagate intermediate results.

    dispersion_trend_shared_state : dict
        Shared state produced by the aggregation step, with among others:
        - "fitted_dispersions": the fitted dispersions,
        - "prior_disp_var": the prior dispersion variance.

    round_idx : int
        The updated round index.

    """
    # Local step: each node reports its means and dispersions.
    # TODO : merge this step with the last steps from genewise dispersion
    local_outputs = local_step(
        local_method=self.get_local_mean_and_dispersion,
        train_data_nodes=train_data_nodes,
        input_local_states=local_states,
        output_local_states=local_states,
        input_shared_state=genewise_dispersions_shared_state,
        aggregation_id=aggregation_node.organization_id,
        description="Get local means and dispersions",
        round_idx=round_idx,
        clean_models=clean_models,
    )
    local_states, shared_states, round_idx = local_outputs

    # Aggregation step: fit the dispersion trend from the local reports.
    aggregation_outputs = aggregation_step(
        aggregation_method=self.agg_fit_dispersion_trend_and_prior_dispersion,
        train_data_nodes=train_data_nodes,
        aggregation_node=aggregation_node,
        input_shared_states=shared_states,
        description="Fitting dispersion trend",
        round_idx=round_idx,
        clean_models=clean_models,
    )
    dispersion_trend_shared_state, round_idx = aggregation_outputs

    return local_states, dispersion_trend_shared_state, round_idx

substeps

Module containing the substeps for fitting the dispersion trend and computing the dispersion prior.

AggFitDispersionTrendAndPrior

Mixin class to implement the fit of the dispersion trend.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_dispersion_prior/substeps.py
class AggFitDispersionTrendAndPrior:
    """Mixin class to implement the fit of the dispersion trend."""

    # Minimum dispersion; used below to build the 10x and 100x thresholds.
    min_disp: float

    @remote
    @log_remote
    def agg_fit_dispersion_trend_and_prior_dispersion(self, shared_states):
        """
        Fit the dispersion trend, and compute the dispersion prior.

        Parameters
        ----------
        shared_states : list of dict
            Shared states from the local step, each with the following keys:
            - genewise_dispersions: np.ndarray of shape (n_genes,)
            - n_params: int
            - non_zero: np.ndarray of shape (n_genes,)
            - mean_normed_counts: np.ndarray of shape (n_genes,)
            - n_obs: int
            "genewise_dispersions", "n_params" and "non_zero" are read from
            the first state only; "mean_normed_counts" and "n_obs" are
            combined across all states.

        Returns
        -------
        dict
            dict with the following keys:
            - prior_disp_var: float
                The prior dispersion variance.
            - _squared_logres: float
                The squared log-residuals.
            - trend_coeffs: np.ndarray of shape (2,)
                The coefficients of the parametric dispersion trend.
            - fitted_dispersions: np.ndarray of shape (n_genes,)
                The fitted dispersions, computed from the dispersion trend.
                Entries of genes outside the ``non_zero`` mask are NaN.
            - disp_function_type: str
                The type of dispersion function (parametric or mean).
            - mean_disp: float, optional
                The mean dispersion (if "mean" fit type), None otherwise.

        """
        genewise_dispersions = shared_states[0]["genewise_dispersions"]
        n_params = shared_states[0]["n_params"]
        non_zero = shared_states[0]["non_zero"]
        n_total_obs = sum(state["n_obs"] for state in shared_states)
        # Global mean of the normed counts: average of the per-state means,
        # weighted by the number of observations of each state.
        mean_normed_counts = (
            sum(
                state["mean_normed_counts"] * state["n_obs"]
                for state in shared_states
            )
            / n_total_obs
        )

        # Exclude all-zero counts
        targets = pd.Series(
            genewise_dispersions.copy(),
        )
        targets = targets[non_zero]
        covariates = pd.Series(1 / mean_normed_counts[non_zero], index=targets.index)

        # Drop genes whose covariate is infinite or NaN (vectorized; the two
        # series share the same index, so one mask filters both).
        finite_covariates = np.isfinite(covariates)
        targets = targets[finite_covariates]
        covariates = covariates[finite_covariates]

        # Initialize coefficients
        old_coeffs = pd.Series([0.1, 0.1])
        coeffs = pd.Series([1.0, 1.0])
        mean_disp = None

        disp_function_type = "parametric"
        # Refit until the coefficients stabilise (sum of squared log-ratios of
        # successive coefficients below 1e-6) or one of them is <= 1e-10.
        while (coeffs > 1e-10).all() and (
            np.log(np.abs(coeffs / old_coeffs)) ** 2
        ).sum() >= 1e-6:
            old_coeffs = coeffs
            (
                coeffs,
                predictions,
                converged,
            ) = DefaultInference().dispersion_trend_gamma_glm(covariates, targets)

            if not converged or (coeffs <= 1e-10).any():
                warnings.warn(
                    "The dispersion trend curve fitting did not converge. "
                    "Switching to a mean-based dispersion trend.",
                    UserWarning,
                    stacklevel=2,
                )
                mean_disp = trim_mean(
                    genewise_dispersions[genewise_dispersions > 10 * self.min_disp],
                    proportiontocut=0.001,
                )
                disp_function_type = "mean"
                # NOTE(review): the loop is not exited here, so the "mean"
                # fallback persists even if a later iteration converges.
                # Confirm this is intended.

            pred_ratios = genewise_dispersions[covariates.index] / predictions

            # Discard genes whose dispersion is far from the prediction
            # before the next fit.
            targets.drop(
                targets[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index,
                inplace=True,
            )
            covariates.drop(
                covariates[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index,
                inplace=True,
            )

        # ``np.nan`` rather than the ``np.NaN`` alias, which NumPy 2.0 removed.
        fitted_dispersions = np.full_like(genewise_dispersions, np.nan)

        fitted_dispersions[non_zero] = disp_function(
            mean_normed_counts[non_zero],
            disp_function_type=disp_function_type,
            coeffs=coeffs,
            mean_disp=mean_disp,
        )

        disp_residuals = np.log(genewise_dispersions[non_zero]) - np.log(
            fitted_dispersions[non_zero]
        )

        # Compute squared log-residuals and prior variance based on genes whose
        # dispersions are above 100 * min_disp. This is to reproduce DESeq2's behaviour.
        above_min_disp = genewise_dispersions[non_zero] >= (100 * self.min_disp)

        _squared_logres = mean_absolute_deviation(disp_residuals[above_min_disp]) ** 2

        # The prior variance is floored at 0.25.
        prior_disp_var = np.maximum(
            _squared_logres - polygamma(1, (n_total_obs - n_params) / 2),
            0.25,
        )

        return {
            "prior_disp_var": prior_disp_var,
            "_squared_logres": _squared_logres,
            "trend_coeffs": coeffs,
            "fitted_dispersions": fitted_dispersions,
            "disp_function_type": disp_function_type,
            "mean_disp": mean_disp,
        }
agg_fit_dispersion_trend_and_prior_dispersion(shared_states)

Fit the dispersion trend, and compute the dispersion prior.

Parameters:

Name Type Description Default
shared_states list

List of shared states from the local step, each a dict with the following keys: - genewise_dispersions: np.ndarray of shape (n_genes,) - n_params: int - non_zero: np.ndarray of shape (n_genes,) - mean_normed_counts: np.ndarray of shape (n_genes,) - n_obs: int

required

Returns:

Type Description
dict

dict with the following keys: - prior_disp_var: float The prior dispersion variance. - _squared_logres: float The squared log-residuals. - trend_coeffs: np.ndarray of shape (2,) The coefficients of the parametric dispersion trend. - fitted_dispersions: np.ndarray of shape (n_genes,) The fitted dispersions, computed from the dispersion trend. - disp_function_type: str The type of dispersion function (parametric or mean). - mean_disp: float, optional The mean dispersion (if "mean" fit type).

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_dispersion_prior/substeps.py
@remote
@log_remote
def agg_fit_dispersion_trend_and_prior_dispersion(self, shared_states):
    """
    Fit the dispersion trend, and compute the dispersion prior.

    Parameters
    ----------
    shared_states : list of dict
        Shared states from the local step, each with the following keys:
        - genewise_dispersions: np.ndarray of shape (n_genes,)
        - n_params: int
        - non_zero: np.ndarray of shape (n_genes,)
        - mean_normed_counts: np.ndarray of shape (n_genes,)
        - n_obs: int
        "genewise_dispersions", "n_params" and "non_zero" are read from
        the first state only; "mean_normed_counts" and "n_obs" are
        combined across all states.

    Returns
    -------
    dict
        dict with the following keys:
        - prior_disp_var: float
            The prior dispersion variance.
        - _squared_logres: float
            The squared log-residuals.
        - trend_coeffs: np.ndarray of shape (2,)
            The coefficients of the parametric dispersion trend.
        - fitted_dispersions: np.ndarray of shape (n_genes,)
            The fitted dispersions, computed from the dispersion trend.
            Entries of genes outside the ``non_zero`` mask are NaN.
        - disp_function_type: str
            The type of dispersion function (parametric or mean).
        - mean_disp: float, optional
            The mean dispersion (if "mean" fit type), None otherwise.

    """
    genewise_dispersions = shared_states[0]["genewise_dispersions"]
    n_params = shared_states[0]["n_params"]
    non_zero = shared_states[0]["non_zero"]
    n_total_obs = sum(state["n_obs"] for state in shared_states)
    # Global mean of the normed counts: average of the per-state means,
    # weighted by the number of observations of each state.
    mean_normed_counts = (
        sum(
            state["mean_normed_counts"] * state["n_obs"]
            for state in shared_states
        )
        / n_total_obs
    )

    # Exclude all-zero counts
    targets = pd.Series(
        genewise_dispersions.copy(),
    )
    targets = targets[non_zero]
    covariates = pd.Series(1 / mean_normed_counts[non_zero], index=targets.index)

    # Drop genes whose covariate is infinite or NaN (vectorized; the two
    # series share the same index, so one mask filters both).
    finite_covariates = np.isfinite(covariates)
    targets = targets[finite_covariates]
    covariates = covariates[finite_covariates]

    # Initialize coefficients
    old_coeffs = pd.Series([0.1, 0.1])
    coeffs = pd.Series([1.0, 1.0])
    mean_disp = None

    disp_function_type = "parametric"
    # Refit until the coefficients stabilise (sum of squared log-ratios of
    # successive coefficients below 1e-6) or one of them is <= 1e-10.
    while (coeffs > 1e-10).all() and (
        np.log(np.abs(coeffs / old_coeffs)) ** 2
    ).sum() >= 1e-6:
        old_coeffs = coeffs
        (
            coeffs,
            predictions,
            converged,
        ) = DefaultInference().dispersion_trend_gamma_glm(covariates, targets)

        if not converged or (coeffs <= 1e-10).any():
            warnings.warn(
                "The dispersion trend curve fitting did not converge. "
                "Switching to a mean-based dispersion trend.",
                UserWarning,
                stacklevel=2,
            )
            mean_disp = trim_mean(
                genewise_dispersions[genewise_dispersions > 10 * self.min_disp],
                proportiontocut=0.001,
            )
            disp_function_type = "mean"
            # NOTE(review): the loop is not exited here, so the "mean"
            # fallback persists even if a later iteration converges.
            # Confirm this is intended.

        pred_ratios = genewise_dispersions[covariates.index] / predictions

        # Discard genes whose dispersion is far from the prediction
        # before the next fit.
        targets.drop(
            targets[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index,
            inplace=True,
        )
        covariates.drop(
            covariates[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index,
            inplace=True,
        )

    # ``np.nan`` rather than the ``np.NaN`` alias, which NumPy 2.0 removed.
    fitted_dispersions = np.full_like(genewise_dispersions, np.nan)

    fitted_dispersions[non_zero] = disp_function(
        mean_normed_counts[non_zero],
        disp_function_type=disp_function_type,
        coeffs=coeffs,
        mean_disp=mean_disp,
    )

    disp_residuals = np.log(genewise_dispersions[non_zero]) - np.log(
        fitted_dispersions[non_zero]
    )

    # Compute squared log-residuals and prior variance based on genes whose
    # dispersions are above 100 * min_disp. This is to reproduce DESeq2's behaviour.
    above_min_disp = genewise_dispersions[non_zero] >= (100 * self.min_disp)

    _squared_logres = mean_absolute_deviation(disp_residuals[above_min_disp]) ** 2

    # The prior variance is floored at 0.25.
    prior_disp_var = np.maximum(
        _squared_logres - polygamma(1, (n_total_obs - n_params) / 2),
        0.25,
    )

    return {
        "prior_disp_var": prior_disp_var,
        "_squared_logres": _squared_logres,
        "trend_coeffs": coeffs,
        "fitted_dispersions": fitted_dispersions,
        "disp_function_type": disp_function_type,
        "mean_disp": mean_disp,
    }

LocGetMeanDispersionAndMean

Mixin to get the local mean and dispersion.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_dispersion_prior/substeps.py
class LocGetMeanDispersionAndMean:
    """Mixin providing the local step that reports gene means and dispersions."""

    # AnnData read and updated by ``get_local_mean_and_dispersion``.
    local_adata: ad.AnnData

    @remote_data
    @log_remote_data
    @reconstruct_adatas
    def get_local_mean_and_dispersion(
        self,
        data_from_opener,
        shared_state: dict,
    ) -> dict:
        # pylint: disable=unused-argument
        """Store the gene-wise dispersions and report local statistics.

        Parameters
        ----------
        data_from_opener : ad.AnnData
            AnnData returned by the opener. Not used.

        shared_state : dict
            Shared state returned by the last step of the gene-wise dispersion
            computation. Its "genewise_dispersions" entry is saved into
            ``local_adata.varm``.

        Returns
        -------
        dict
            Local results to be shared with the aggregation node, with the
            following keys:
            - mean_normed_counts: np.ndarray[float] of shape (n_genes,)
                The mean normed counts.
            - n_obs: int
                The number of observations.
            - non_zero: np.ndarray[bool] of shape (n_genes,)
                Mask of the genes with non zero counts.
            - genewise_dispersions: np.ndarray[float] of shape (n_genes,)
                The genewise dispersions.
            - n_params: int
                The value stored under ``local_adata.uns["n_params"]``.

        """
        adata = self.local_adata

        # Keep the gene-wise dispersions of the previous step on the AnnData.
        # Dispersions of all-zero genes should already be NaN.
        adata.varm["genewise_dispersions"] = shared_state["genewise_dispersions"]

        # TODO: these could be gathered earlier and sent directly to the aggregation
        # node.
        mean_normed_counts = adata.layers["normed_counts"].mean(0)
        local_summary = {
            "mean_normed_counts": mean_normed_counts,
            "n_obs": adata.n_obs,
            "non_zero": adata.varm["non_zero"],
            "genewise_dispersions": adata.varm["genewise_dispersions"],
            "n_params": adata.uns["n_params"],
        }
        return local_summary
get_local_mean_and_dispersion(data_from_opener, shared_state)

Return local gene means and dispersion.

Parameters:

Name Type Description Default
data_from_opener AnnData

AnnData returned by the opener. Not used.

required
shared_state dict

Shared state returned by the last step of gene-wise dispersion computation. Contains a "genewise_dispersions" key with the gene-wise dispersions.

required

Returns:

Type Description
dict

Local results to be shared via shared_state to the aggregation node. dict with the following keys: - mean_normed_counts: np.ndarray[float] of shape (n_genes,) The mean normed counts. - n_obs: int The number of observations. - non_zero: np.ndarray[bool] of shape (n_genes,) Mask of the genes with non zero counts. - genewise_dispersions: np.ndarray[float] of shape (n_genes,) The genewise dispersions. - n_params: int The number of parameters.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_dispersion_prior/substeps.py
@remote_data
@log_remote_data
@reconstruct_adatas
def get_local_mean_and_dispersion(
    self,
    data_from_opener,
    shared_state: dict,
) -> dict:
    # pylint: disable=unused-argument
    """Store the gene-wise dispersions and report local statistics.

    Parameters
    ----------
    data_from_opener : ad.AnnData
        AnnData returned by the opener. Not used.

    shared_state : dict
        Shared state returned by the last step of the gene-wise dispersion
        computation. Its "genewise_dispersions" entry is saved into
        ``local_adata.varm``.

    Returns
    -------
    dict
        Local results to be shared with the aggregation node, with the
        following keys:
        - mean_normed_counts: np.ndarray[float] of shape (n_genes,)
            The mean normed counts.
        - n_obs: int
            The number of observations.
        - non_zero: np.ndarray[bool] of shape (n_genes,)
            Mask of the genes with non zero counts.
        - genewise_dispersions: np.ndarray[float] of shape (n_genes,)
            The genewise dispersions.
        - n_params: int
            The value stored under ``local_adata.uns["n_params"]``.

    """
    adata = self.local_adata

    # Keep the gene-wise dispersions of the previous step on the AnnData.
    # Dispersions of all-zero genes should already be NaN.
    adata.varm["genewise_dispersions"] = shared_state["genewise_dispersions"]

    # TODO: these could be gathered earlier and sent directly to the aggregation
    # node.
    mean_normed_counts = adata.layers["normed_counts"].mean(0)
    local_summary = {
        "mean_normed_counts": mean_normed_counts,
        "n_obs": adata.n_obs,
        "non_zero": adata.varm["non_zero"],
        "genewise_dispersions": adata.varm["genewise_dispersions"],
        "n_params": adata.uns["n_params"],
    }
    return local_summary

LocUpdateFittedDispersions

Mixin to update the fitted dispersions after replacing outliers.

To use in refit mode only

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_dispersion_prior/substeps.py
class LocUpdateFittedDispersions:
    """Mixin to update the fitted dispersions after replacing outliers.

    To use in refit mode only.
    """

    local_adata: ad.AnnData
    refit_adata: ad.AnnData

    @remote_data
    @log_remote_data
    @reconstruct_adatas
    def loc_update_fitted_dispersions(
        self,
        data_from_opener,
        shared_state: dict,
    ) -> None:
        """
        Update the fitted dispersions after replacing outliers.

        The gene-wise dispersions from ``shared_state`` are stored on
        ``refit_adata``, the dispersion function type is copied over from
        ``local_adata``, and the trend function is evaluated on the genes
        selected by ``refit_adata.varm["non_zero"]``. All other genes get NaN.

        Parameters
        ----------
        data_from_opener : ad.AnnData
            AnnData returned by the opener. Not used.

        shared_state : dict
            A dictionary with a "genewise_dispersions" key, whose value is stored
            in ``refit_adata.varm["genewise_dispersions"]``.

        Raises
        ------
        ValueError
            Raised by ``disp_function`` if the dispersion function type is
            neither "parametric" nor "mean".
        """
        # Start by updating gene-wise dispersions
        self.refit_adata.varm["genewise_dispersions"] = shared_state[
            "genewise_dispersions"
        ]

        # Update the fitted dispersions
        non_zero = self.refit_adata.varm["non_zero"]
        disp_function_type = self.local_adata.uns["disp_function_type"]
        self.refit_adata.uns["disp_function_type"] = disp_function_type

        # np.nan, not np.NaN: the capitalised alias was removed in NumPy 2.0.
        fitted_dispersions = np.full_like(
            self.refit_adata.varm["genewise_dispersions"], np.nan
        )

        # Only look up the trend parameter required by the selected function:
        # "parametric" needs the trend coefficients, "mean" needs the mean
        # dispersion. Passing mean_disp=None in "mean" mode would trip the
        # check inside disp_function.
        is_parametric = disp_function_type == "parametric"
        is_mean = disp_function_type == "mean"

        fitted_dispersions[non_zero] = disp_function(
            self.refit_adata.varm["_normed_means"][non_zero],
            disp_function_type=disp_function_type,
            coeffs=self.refit_adata.uns["trend_coeffs"] if is_parametric else None,
            mean_disp=self.refit_adata.uns["mean_disp"] if is_mean else None,
        )

        self.refit_adata.varm["fitted_dispersions"] = fitted_dispersions
loc_update_fitted_dispersions(data_from_opener, shared_state)

Update the fitted dispersions after replacing outliers.

Parameters:

Name Type Description Default
data_from_opener AnnData

AnnData returned by the opener. Not used.

required
shared_state dict

A dictionary with a "genewise_dispersions" key, containing the gene-wise dispersions to store on the refit AnnData.

required
Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_dispersion_prior/substeps.py
@remote_data
@log_remote_data
@reconstruct_adatas
def loc_update_fitted_dispersions(
    self,
    data_from_opener,
    shared_state: dict,
) -> None:
    """
    Update the fitted dispersions after replacing outliers.

    The gene-wise dispersions from ``shared_state`` are stored on
    ``refit_adata``, the dispersion function type is copied over from
    ``local_adata``, and the trend function is evaluated on the genes
    selected by ``refit_adata.varm["non_zero"]``. All other genes get NaN.

    Parameters
    ----------
    data_from_opener : ad.AnnData
        AnnData returned by the opener. Not used.

    shared_state : dict
        A dictionary with a "genewise_dispersions" key, whose value is stored
        in ``refit_adata.varm["genewise_dispersions"]``.

    Raises
    ------
    ValueError
        Raised by ``disp_function`` if the dispersion function type is
        neither "parametric" nor "mean".
    """
    # Start by updating gene-wise dispersions
    self.refit_adata.varm["genewise_dispersions"] = shared_state[
        "genewise_dispersions"
    ]

    # Update the fitted dispersions
    non_zero = self.refit_adata.varm["non_zero"]
    disp_function_type = self.local_adata.uns["disp_function_type"]
    self.refit_adata.uns["disp_function_type"] = disp_function_type

    # np.nan, not np.NaN: the capitalised alias was removed in NumPy 2.0.
    fitted_dispersions = np.full_like(
        self.refit_adata.varm["genewise_dispersions"], np.nan
    )

    # Only look up the trend parameter required by the selected function:
    # "parametric" needs the trend coefficients, "mean" needs the mean
    # dispersion. Passing mean_disp=None in "mean" mode would trip the
    # check inside disp_function.
    is_parametric = disp_function_type == "parametric"
    is_mean = disp_function_type == "mean"

    fitted_dispersions[non_zero] = disp_function(
        self.refit_adata.varm["_normed_means"][non_zero],
        disp_function_type=disp_function_type,
        coeffs=self.refit_adata.uns["trend_coeffs"] if is_parametric else None,
        mean_disp=self.refit_adata.uns["mean_disp"] if is_mean else None,
    )

    self.refit_adata.varm["fitted_dispersions"] = fitted_dispersions

utils

disp_function(x, disp_function_type, coeffs=None, mean_disp=None)

Return the dispersion trend function at x.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_dispersion_prior/utils.py
def disp_function(
    x,
    disp_function_type,
    coeffs: Union["pd.Series[float]", np.ndarray] | None = None,
    mean_disp: float | None = None,
) -> float | np.ndarray:
    """Return the dispersion trend function at x."""
    if disp_function_type == "parametric":
        assert coeffs is not None, "coeffs must be provided for parametric dispersion."
        return dispersion_trend(x, coeffs=coeffs)
    elif disp_function_type == "mean":
        assert mean_disp is not None, "mean_disp must be provided for mean dispersion."
        return np.full_like(x, mean_disp)
    else:
        raise ValueError(
            "disp_function_type must be 'parametric' or 'mean',"
            f" got {disp_function_type}"
        )

compute_genewise_dispersions

Module containing the mixin class to compute genewise dispersions.

compute_MoM_dispersions

Module to implement the computation of MoM dispersions.

compute_MoM_dispersions

Main module to compute method of moments (MoM) dispersions.

ComputeMoMDispersions

Bases: ComputeRoughDispersions, LocInvSizeMean, AggMomentsDispersion

Mixin class to implement the computation of MoM dispersions.

Relies on the ComputeRoughDispersions class, in addition to substeps.

Methods:

Name Description
compute_MoM_dispersions

The method to compute the MoM dispersions, that must be used in the main pipeline.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/compute_MoM_dispersions.py
class ComputeMoMDispersions(
    ComputeRoughDispersions,
    LocInvSizeMean,
    AggMomentsDispersion,
):
    """Mixin orchestrating the method of moments (MoM) dispersion estimation.

    The rough dispersions stage is delegated to ``compute_rough_dispersions``;
    two further substeps (one local, one aggregation) follow it.

    Methods
    -------
    compute_MoM_dispersions
        Entry point to call from the main pipeline to obtain MoM dispersions.

    """

    def compute_MoM_dispersions(
        self,
        train_data_nodes,
        aggregation_node,
        local_states,
        gram_features_shared_states,
        round_idx,
        clean_models,
        refit_mode: bool = False,
    ):
        """Run the federated steps that produce the MoM dispersions.

        Parameters
        ----------
        train_data_nodes: list
            List of TrainDataNode.

        aggregation_node: AggregationNode
            The aggregation node.

        local_states: dict
            Local states, threaded through every step to propagate
            intermediate results.

        gram_features_shared_states: list
            Shared states output by the compute_size_factors step, holding
            "local_gram_matrix" and "local_features" fields. Forwarded to
            ``compute_rough_dispersions``.

        round_idx: int
            The current round.

        clean_models: bool
            Whether to clean the models after the computation.

        refit_mode: bool
            Whether to run in refit mode, i.e. after cooks outliers were
            replaced, on `refit_adata`s instead of `local_adata`s
            (default: False).

        Returns
        -------
        local_states: dict
            Local states after the last local step.

        mom_dispersions_shared_state: dict
            Shared state returned by the final aggregation step.

        round_idx: int
            The updated round number.

        """
        # ---- Stage 1: rough dispersions ---- #
        local_states, rough_disp_state, round_idx = self.compute_rough_dispersions(
            train_data_nodes,
            aggregation_node,
            local_states,
            refit_mode=refit_mode,
            clean_models=clean_models,
            round_idx=round_idx,
            gram_features_shared_states=gram_features_shared_states,
        )

        # ---- Stage 2: local means needed by the moments estimate ---- #
        local_states, inv_size_mean_states, round_idx = local_step(
            description="Compute local inverse size factor means.",
            local_method=self.local_inverse_size_mean,
            method_params={"refit_mode": refit_mode},
            train_data_nodes=train_data_nodes,
            aggregation_id=aggregation_node.organization_id,
            input_local_states=local_states,
            input_shared_state=rough_disp_state,
            output_local_states=local_states,
            round_idx=round_idx,
            clean_models=clean_models,
        )

        # ---- Stage 3: aggregate into the global MoM dispersions ---- #
        aggregated = aggregation_step(
            description="Compute global MoM dispersions",
            aggregation_method=self.aggregate_moments_dispersions,
            train_data_nodes=train_data_nodes,
            aggregation_node=aggregation_node,
            input_shared_states=inv_size_mean_states,
            round_idx=round_idx,
            clean_models=clean_models,
        )
        mom_dispersions_shared_state, round_idx = aggregated

        return local_states, mom_dispersions_shared_state, round_idx
compute_MoM_dispersions(train_data_nodes, aggregation_node, local_states, gram_features_shared_states, round_idx, clean_models, refit_mode=False)

Compute method of moments dispersions.

Parameters:

Name Type Description Default
train_data_nodes

List of TrainDataNode.

required
aggregation_node

The aggregation node.

required
local_states

Local states. Required to propagate intermediate results.

required
gram_features_shared_states

The list of shared states output by the compute_size_factors step. They contain "local_gram_matrix" and "local_features" fields.

required
round_idx

The current round.

required
clean_models

Whether to clean the models after the computation.

required
refit_mode bool

Whether to run the pipeline in refit mode, after cooks outliers were replaced. If True, the pipeline will be run on refit_adatas instead of local_adatas (default: False).

False

Returns:

Name Type Description
local_states dict

Local states. Required to propagate intermediate results.

mom_dispersions_shared_state dict

Shared states containing MoM dispersions.

round_idx int

The updated round number.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/compute_MoM_dispersions.py
def compute_MoM_dispersions(
    self,
    train_data_nodes,
    aggregation_node,
    local_states,
    gram_features_shared_states,
    round_idx,
    clean_models,
    refit_mode: bool = False,
):
    """Run the federated steps that produce the MoM dispersions.

    Parameters
    ----------
    train_data_nodes: list
        List of TrainDataNode.

    aggregation_node: AggregationNode
        The aggregation node.

    local_states: dict
        Local states, threaded through every step to propagate
        intermediate results.

    gram_features_shared_states: list
        Shared states output by the compute_size_factors step, holding
        "local_gram_matrix" and "local_features" fields. Forwarded to
        ``compute_rough_dispersions``.

    round_idx: int
        The current round.

    clean_models: bool
        Whether to clean the models after the computation.

    refit_mode: bool
        Whether to run in refit mode, i.e. after cooks outliers were
        replaced, on `refit_adata`s instead of `local_adata`s
        (default: False).

    Returns
    -------
    local_states: dict
        Local states after the last local step.

    mom_dispersions_shared_state: dict
        Shared state returned by the final aggregation step.

    round_idx: int
        The updated round number.

    """
    # ---- Stage 1: rough dispersions ---- #
    local_states, rough_disp_state, round_idx = self.compute_rough_dispersions(
        train_data_nodes,
        aggregation_node,
        local_states,
        refit_mode=refit_mode,
        clean_models=clean_models,
        round_idx=round_idx,
        gram_features_shared_states=gram_features_shared_states,
    )

    # ---- Stage 2: local means needed by the moments estimate ---- #
    local_states, inv_size_mean_states, round_idx = local_step(
        description="Compute local inverse size factor means.",
        local_method=self.local_inverse_size_mean,
        method_params={"refit_mode": refit_mode},
        train_data_nodes=train_data_nodes,
        aggregation_id=aggregation_node.organization_id,
        input_local_states=local_states,
        input_shared_state=rough_disp_state,
        output_local_states=local_states,
        round_idx=round_idx,
        clean_models=clean_models,
    )

    # ---- Stage 3: aggregate into the global MoM dispersions ---- #
    aggregated = aggregation_step(
        description="Compute global MoM dispersions",
        aggregation_method=self.aggregate_moments_dispersions,
        train_data_nodes=train_data_nodes,
        aggregation_node=aggregation_node,
        input_shared_states=inv_size_mean_states,
        round_idx=round_idx,
        clean_models=clean_models,
    )
    mom_dispersions_shared_state, round_idx = aggregated

    return local_states, mom_dispersions_shared_state, round_idx

compute_rough_dispersions

Module to compute rough dispersions.

ComputeRoughDispersions

Bases: AggRoughDispersion, LocRoughDispersion, AggCreateRoughDispersionsSystem

Mixin class to implement the computation of rough dispersions.

Methods:

Name Description
compute_rough_dispersions

The method to compute the rough dispersions, that must be used in the main pipeline.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/compute_rough_dispersions.py
class ComputeRoughDispersions(
    AggRoughDispersion,
    LocRoughDispersion,
    AggCreateRoughDispersionsSystem,
):
    """Mixin orchestrating the rough dispersions estimation.

    Methods
    -------
    compute_rough_dispersions
        Entry point to call from the main pipeline to obtain rough
        dispersions.

    """

    def compute_rough_dispersions(
        self,
        train_data_nodes,
        aggregation_node,
        local_states,
        gram_features_shared_states,
        round_idx,
        clean_models,
        refit_mode: bool = False,
    ):
        """Run the three federated steps that produce the rough dispersions.

        An aggregation step first assembles the global system from the
        incoming shared states, a local step then computes per-center rough
        dispersions, and a last aggregation step merges them.

        Parameters
        ----------
        train_data_nodes: list
            List of TrainDataNode.

        aggregation_node: AggregationNode
            The aggregation node.

        local_states: dict
            Local states, threaded through the local step to propagate
            intermediate results.

        gram_features_shared_states: list
            Shared states output by the compute_size_factors step, holding
            "local_gram_matrix" and "local_features" fields.

        round_idx: int
            The current round.

        clean_models: bool
            Whether to clean the models after the computation.

        refit_mode: bool
            Whether to run in refit mode, i.e. after cooks outliers were
            replaced, on `refit_adata`s instead of `local_adata`s
            (default: False).

        Returns
        -------
        local_states: dict
            Local states after the local step.

        rough_dispersion_shared_state: dict
            Shared state returned by the final aggregation step.

        round_idx: int
            The updated round number.

        """
        # TODO: in refit mode, we need to gather the gram matrix and the features some
        #  way

        # ---- Step 1 (aggregation): assemble the global system ---- #
        system_step = aggregation_step(
            description="Solving system for rough dispersions",
            aggregation_method=self.create_rough_dispersions_system,
            method_params={"refit_mode": refit_mode},
            train_data_nodes=train_data_nodes,
            aggregation_node=aggregation_node,
            input_shared_states=gram_features_shared_states,
            round_idx=round_idx,
            clean_models=clean_models,
        )
        rough_dispersion_system_shared_state, round_idx = system_step

        # ---- Step 2 (local): per-center rough dispersions ---- #
        local_states, local_rough_states, round_idx = local_step(
            description="Computing local rough dispersions",
            local_method=self.local_rough_dispersions,
            method_params={"refit_mode": refit_mode},
            train_data_nodes=train_data_nodes,
            aggregation_id=aggregation_node.organization_id,
            input_local_states=local_states,
            input_shared_state=rough_dispersion_system_shared_state,
            output_local_states=local_states,
            round_idx=round_idx,
            clean_models=clean_models,
        )

        # ---- Step 3 (aggregation): global rough dispersions ---- #
        merge_step = aggregation_step(
            description="Compute global rough dispersions",
            aggregation_method=self.aggregate_rough_dispersions,
            train_data_nodes=train_data_nodes,
            aggregation_node=aggregation_node,
            input_shared_states=local_rough_states,
            round_idx=round_idx,
            clean_models=clean_models,
        )
        rough_dispersion_shared_state, round_idx = merge_step

        return local_states, rough_dispersion_shared_state, round_idx
compute_rough_dispersions(train_data_nodes, aggregation_node, local_states, gram_features_shared_states, round_idx, clean_models, refit_mode=False)

Compute rough dispersions.

Parameters:

Name Type Description Default
train_data_nodes

List of TrainDataNode.

required
aggregation_node

The aggregation node.

required
local_states

Local states. Required to propagate intermediate results.

required
gram_features_shared_states

The list of shared states output by the compute_size_factors step. They contain "local_gram_matrix" and "local_features" fields.

required
round_idx

The current round.

required
clean_models

Whether to clean the models after the computation.

required
refit_mode bool

Whether to run the pipeline in refit mode, after cooks outliers were replaced. If True, the pipeline will be run on refit_adatas instead of local_adatas (default: False).

False

Returns:

Name Type Description
local_states dict

Local states. Required to propagate intermediate results.

rough_dispersion_shared_state dict

Shared states containing rough dispersions.

round_idx int

The updated round number.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/compute_rough_dispersions.py
def compute_rough_dispersions(
    self,
    train_data_nodes,
    aggregation_node,
    local_states,
    gram_features_shared_states,
    round_idx,
    clean_models,
    refit_mode: bool = False,
):
    """Run the three federated steps that produce the rough dispersions.

    An aggregation step first assembles the global system from the
    incoming shared states, a local step then computes per-center rough
    dispersions, and a last aggregation step merges them.

    Parameters
    ----------
    train_data_nodes: list
        List of TrainDataNode.

    aggregation_node: AggregationNode
        The aggregation node.

    local_states: dict
        Local states, threaded through the local step to propagate
        intermediate results.

    gram_features_shared_states: list
        Shared states output by the compute_size_factors step, holding
        "local_gram_matrix" and "local_features" fields.

    round_idx: int
        The current round.

    clean_models: bool
        Whether to clean the models after the computation.

    refit_mode: bool
        Whether to run in refit mode, i.e. after cooks outliers were
        replaced, on `refit_adata`s instead of `local_adata`s
        (default: False).

    Returns
    -------
    local_states: dict
        Local states after the local step.

    rough_dispersion_shared_state: dict
        Shared state returned by the final aggregation step.

    round_idx: int
        The updated round number.

    """
    # TODO: in refit mode, we need to gather the gram matrix and the features some
    #  way

    # ---- Step 1 (aggregation): assemble the global system ---- #
    system_step = aggregation_step(
        description="Solving system for rough dispersions",
        aggregation_method=self.create_rough_dispersions_system,
        method_params={"refit_mode": refit_mode},
        train_data_nodes=train_data_nodes,
        aggregation_node=aggregation_node,
        input_shared_states=gram_features_shared_states,
        round_idx=round_idx,
        clean_models=clean_models,
    )
    rough_dispersion_system_shared_state, round_idx = system_step

    # ---- Step 2 (local): per-center rough dispersions ---- #
    local_states, local_rough_states, round_idx = local_step(
        description="Computing local rough dispersions",
        local_method=self.local_rough_dispersions,
        method_params={"refit_mode": refit_mode},
        train_data_nodes=train_data_nodes,
        aggregation_id=aggregation_node.organization_id,
        input_local_states=local_states,
        input_shared_state=rough_dispersion_system_shared_state,
        output_local_states=local_states,
        round_idx=round_idx,
        clean_models=clean_models,
    )

    # ---- Step 3 (aggregation): global rough dispersions ---- #
    merge_step = aggregation_step(
        description="Compute global rough dispersions",
        aggregation_method=self.aggregate_rough_dispersions,
        train_data_nodes=train_data_nodes,
        aggregation_node=aggregation_node,
        input_shared_states=local_rough_states,
        round_idx=round_idx,
        clean_models=clean_models,
    )
    rough_dispersion_shared_state, round_idx = merge_step

    return local_states, rough_dispersion_shared_state, round_idx

substeps

Module to implement the substeps for the rough dispersions step.

This module contains all these substeps as mixin classes.

AggCreateRoughDispersionsSystem

Mixin to solve the linear system for rough dispersions.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/substeps.py
class AggCreateRoughDispersionsSystem:
    """Mixin aggregating the local terms of the rough dispersions system."""

    @remote
    @log_remote
    def create_rough_dispersions_system(self, shared_states, refit_mode: bool = False):
        """Sum the local features and, unless refitting, the local Gram matrices.

        Parameters
        ----------
        shared_states : list
            Per-center results. Each one must hold "local_features", and also
            "local_gram_matrix" when ``refit_mode`` is False.

        refit_mode : bool
            Whether the pipeline runs in refit mode, after cooks outliers were
            replaced. If True, the Gram matrices are not summed
            (default: False).

        Returns
        -------
        dict
            "global_feature_vector": the sum of the local features, plus
            "global_gram_matrix": the sum of the local Gram matrices when
            ``refit_mode`` is False.
        """
        global_system = {
            "global_feature_vector": sum(
                state["local_features"] for state in shared_states
            )
        }

        # In refit mode only the feature vector is returned.
        if refit_mode:
            return global_system

        global_system["global_gram_matrix"] = sum(
            state["local_gram_matrix"] for state in shared_states
        )
        return global_system
create_rough_dispersions_system(shared_states, refit_mode=False)

Create the linear system for rough dispersions.

Parameters:

Name Type Description Default
shared_states list

List of results (local_gram_matrix, local_features) from training nodes.

required
refit_mode bool

Whether to run the pipeline in refit mode, after cooks outliers were replaced. If True, there is no need to compute the Gram matrix which was already computed in the compute_size_factors step (default: False).

False

Returns:

Type Description
dict

The global feature vector and, if refit_mode is False, the global Gram matrix.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/substeps.py
@remote
@log_remote
def create_rough_dispersions_system(self, shared_states, refit_mode: bool = False):
    """Sum the local features and, unless refitting, the local Gram matrices.

    Parameters
    ----------
    shared_states : list
        Per-center results. Each one must hold "local_features", and also
        "local_gram_matrix" when ``refit_mode`` is False.

    refit_mode : bool
        Whether the pipeline runs in refit mode, after cooks outliers were
        replaced. If True, the Gram matrices are not summed
        (default: False).

    Returns
    -------
    dict
        "global_feature_vector": the sum of the local features, plus
        "global_gram_matrix": the sum of the local Gram matrices when
        ``refit_mode`` is False.
    """
    global_system = {
        "global_feature_vector": sum(
            state["local_features"] for state in shared_states
        )
    }

    # In refit mode only the feature vector is returned.
    if refit_mode:
        return global_system

    global_system["global_gram_matrix"] = sum(
        state["local_gram_matrix"] for state in shared_states
    )
    return global_system
AggMomentsDispersion

Mixin to compute MoM dispersions.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/substeps.py
class AggMomentsDispersion:
    """Mixin aggregating local moments into MoM dispersions."""

    local_adata: AnnData
    max_disp: float
    min_disp: float

    @remote
    @log_remote
    def aggregate_moments_dispersions(self, shared_states):
        """Compute the global method of moments (MoM) dispersions.

        Parameters
        ----------
        shared_states : list
            Per-center results. Each one must hold "local_n_obs",
            "local_inverse_size_mean", "local_counts_mean" and
            "local_squared_squared_mean"; the first one must also hold
            "rough_dispersions".

        Returns
        -------
        dict
            With keys:
            - "MoM_dispersions": element-wise minimum of the rough and moments
              dispersions, clipped to ``[min_disp, max(max_disp, n_samples)]``
              and set to NaN where ``non_zero`` is False.
            - "non_zero": mask of the genes whose pooled counts mean is not 0.
            - "tot_num_samples": total number of observations over centers.
            - "tot_counts_mean": pooled counts mean.
        """
        tot_n_obs = sum(state["local_n_obs"] for state in shared_states)

        def pooled_mean(key):
            # Sample-size weighted average of a per-center mean.
            weighted = sum(
                state["local_n_obs"] * state[key] for state in shared_states
            )
            return weighted / tot_n_obs

        # Pooled statistics over all centers.
        tot_inv_size_mean = pooled_mean("local_inverse_size_mean")
        tot_counts_mean = pooled_mean("local_counts_mean")
        tot_squared_mean = pooled_mean("local_squared_squared_mean")
        non_zero = tot_counts_mean != 0

        # Variance from the first two pooled moments, rescaled by n / (n - 1).
        unbiasing_factor = tot_n_obs / (tot_n_obs - 1)
        counts_variance = unbiasing_factor * (tot_squared_mean - tot_counts_mean**2)

        # Moments estimate, left at 0 where the pooled mean is 0.
        moments_dispersions = np.zeros(
            counts_variance.shape, dtype=counts_variance.dtype
        )
        expressed_mean = tot_counts_mean[non_zero]
        moments_dispersions[non_zero] = (
            counts_variance[non_zero] - tot_inv_size_mean * expressed_mean
        ) / expressed_mean**2

        # The rough dispersions are read from the first center only.
        rough_dispersions = shared_states[0]["rough_dispersions"]

        # The upper bound is at least the total number of samples.
        upper_bound = np.maximum(self.max_disp, tot_n_obs)

        MoM_dispersions = np.clip(
            np.minimum(rough_dispersions, moments_dispersions),
            self.min_disp,
            upper_bound,
        )

        # Genes with a zero pooled mean get NaN.
        MoM_dispersions[~non_zero] = np.nan

        return {
            "MoM_dispersions": MoM_dispersions,
            "non_zero": non_zero,
            "tot_num_samples": tot_n_obs,
            "tot_counts_mean": tot_counts_mean,
        }
aggregate_moments_dispersions(shared_states)

Compute global moments dispersions.

Parameters:

Name Type Description Default
shared_states list

List of results (local_inverse_size_mean, local_counts_mean, local_squared_squared_mean, local_n_obs, rough_dispersions) from training nodes.

required

Returns:

Type Description
dict

Global moments dispersions, the mask of all zero genes, the total number of samples (used to set max_disp and lr), and the total normed counts mean (used in the independent filtering step).

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/substeps.py
@remote
@log_remote
def aggregate_moments_dispersions(self, shared_states):
    """Compute the global method of moments (MoM) dispersions.

    Parameters
    ----------
    shared_states : list
        Per-center results. Each one must hold "local_n_obs",
        "local_inverse_size_mean", "local_counts_mean" and
        "local_squared_squared_mean"; the first one must also hold
        "rough_dispersions".

    Returns
    -------
    dict
        With keys:
        - "MoM_dispersions": element-wise minimum of the rough and moments
          dispersions, clipped to ``[min_disp, max(max_disp, n_samples)]``
          and set to NaN where ``non_zero`` is False.
        - "non_zero": mask of the genes whose pooled counts mean is not 0.
        - "tot_num_samples": total number of observations over centers.
        - "tot_counts_mean": pooled counts mean.
    """
    tot_n_obs = sum(state["local_n_obs"] for state in shared_states)

    def pooled_mean(key):
        # Sample-size weighted average of a per-center mean.
        weighted = sum(
            state["local_n_obs"] * state[key] for state in shared_states
        )
        return weighted / tot_n_obs

    # Pooled statistics over all centers.
    tot_inv_size_mean = pooled_mean("local_inverse_size_mean")
    tot_counts_mean = pooled_mean("local_counts_mean")
    tot_squared_mean = pooled_mean("local_squared_squared_mean")
    non_zero = tot_counts_mean != 0

    # Variance from the first two pooled moments, rescaled by n / (n - 1).
    unbiasing_factor = tot_n_obs / (tot_n_obs - 1)
    counts_variance = unbiasing_factor * (tot_squared_mean - tot_counts_mean**2)

    # Moments estimate, left at 0 where the pooled mean is 0.
    moments_dispersions = np.zeros(
        counts_variance.shape, dtype=counts_variance.dtype
    )
    expressed_mean = tot_counts_mean[non_zero]
    moments_dispersions[non_zero] = (
        counts_variance[non_zero] - tot_inv_size_mean * expressed_mean
    ) / expressed_mean**2

    # The rough dispersions are read from the first center only.
    rough_dispersions = shared_states[0]["rough_dispersions"]

    # The upper bound is at least the total number of samples.
    upper_bound = np.maximum(self.max_disp, tot_n_obs)

    MoM_dispersions = np.clip(
        np.minimum(rough_dispersions, moments_dispersions),
        self.min_disp,
        upper_bound,
    )

    # Genes with a zero pooled mean get NaN.
    MoM_dispersions[~non_zero] = np.nan

    return {
        "MoM_dispersions": MoM_dispersions,
        "non_zero": non_zero,
        "tot_num_samples": tot_n_obs,
        "tot_counts_mean": tot_counts_mean,
    }
AggRoughDispersion

Mixin to aggregate local rough dispersions.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/substeps.py
class AggRoughDispersion:
    """Mixin providing the aggregation step for rough dispersions."""

    @remote
    @log_remote
    def aggregate_rough_dispersions(self, shared_states):
        """Combine the local rough dispersion sums into global rough dispersions.

        Parameters
        ----------
        shared_states : list
            One dictionary per training node, each providing the keys
            "local_rough_dispersions", "local_n_obs" and "local_n_params".

        Returns
        -------
        dict
            Dictionary with a single "rough_dispersions" entry: the sum of the
            local contributions divided by (total number of samples - number of
            parameters), floored at 0.

        Raises
        ------
        ValueError
            If the total number of samples is smaller than or equal to the
            number of parameters.
        """
        local_sums = [state["local_rough_dispersions"] for state in shared_states]
        summed_dispersions = sum(local_sums)

        local_sizes = [state["local_n_obs"] for state in shared_states]
        n_samples = sum(local_sizes)

        # The number of parameters is read from the first center only.
        n_design_columns = shared_states[0]["local_n_params"]

        if n_design_columns >= n_samples:
            raise ValueError(
                "The number of samples is smaller or equal to the number of design "
                "variables, i.e., there are no replicates to estimate the "
                "dispersions. Please use a design with fewer variables."
            )

        # Normalise by the residual degrees of freedom and clip negatives to 0.
        degrees_of_freedom = n_samples - n_design_columns
        normalised = summed_dispersions / degrees_of_freedom
        return {"rough_dispersions": np.maximum(normalised, 0)}
aggregate_rough_dispersions(shared_states)

Aggregate local rough dispersions.

Parameters:

Name Type Description Default
shared_states list

List of results (rough_dispersions, n_obs, n_params) from training nodes.

required

Returns:

Type Description
dict

Global rough dispersions.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/substeps.py
@remote
@log_remote
def aggregate_rough_dispersions(self, shared_states):
    """Combine the local rough dispersion sums into global rough dispersions.

    Parameters
    ----------
    shared_states : list
        One dictionary per training node, each providing the keys
        "local_rough_dispersions", "local_n_obs" and "local_n_params".

    Returns
    -------
    dict
        Dictionary with a single "rough_dispersions" entry: the sum of the
        local contributions divided by (total number of samples - number of
        parameters), floored at 0.

    Raises
    ------
    ValueError
        If the total number of samples is smaller than or equal to the
        number of parameters.
    """
    local_sums = [state["local_rough_dispersions"] for state in shared_states]
    summed_dispersions = sum(local_sums)

    local_sizes = [state["local_n_obs"] for state in shared_states]
    n_samples = sum(local_sizes)

    # The number of parameters is read from the first center only.
    n_design_columns = shared_states[0]["local_n_params"]

    if n_design_columns >= n_samples:
        raise ValueError(
            "The number of samples is smaller or equal to the number of design "
            "variables, i.e., there are no replicates to estimate the "
            "dispersions. Please use a design with fewer variables."
        )

    # Normalise by the residual degrees of freedom and clip negatives to 0.
    degrees_of_freedom = n_samples - n_design_columns
    normalised = summed_dispersions / degrees_of_freedom
    return {"rough_dispersions": np.maximum(normalised, 0)}
LocInvSizeMean

Mixin to compute local means of inverse size factors.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/substeps.py
class LocInvSizeMean:
    """Mixin computing the local statistics needed for MoM dispersions."""

    local_adata: AnnData
    refit_adata: AnnData

    @remote_data
    @log_remote_data
    @reconstruct_adatas
    def local_inverse_size_mean(
        self, data_from_opener, shared_state=None, refit_mode: bool = False
    ) -> dict:
        """Compute local means of inverse size factors, counts and squared counts.

        The rough dispersions received from the aggregator are also stored in
        ``adata.varm["_rough_dispersions"]``.

        Parameters
        ----------
        data_from_opener : ad.AnnData
            AnnData returned by the opener. Not used.

        shared_state : dict
            Shared state from the aggregator; its "rough_dispersions" entry is
            read. Although the default is ``None``, the value is indexed
            unconditionally, so a dictionary must be provided.

        refit_mode : bool
            If True, work on ``refit_adata`` instead of ``local_adata``
            (refit after cooks outliers were replaced). (default: False).

        Returns
        -------
        dict
            Dictionary with the keys "local_inverse_size_mean",
            "local_counts_mean", "local_squared_squared_mean" (mean of the
            squared normed counts), "local_n_obs" and "rough_dispersions".
        """
        adata = self.refit_adata if refit_mode else self.local_adata

        rough_dispersions = shared_state["rough_dispersions"]
        adata.varm["_rough_dispersions"] = rough_dispersions

        inverse_size_factors = 1 / adata.obsm["size_factors"]
        normed_counts = adata.layers["normed_counts"]

        local_stats = {
            "local_inverse_size_mean": inverse_size_factors.mean(),
            "local_counts_mean": normed_counts.mean(0),
            "local_squared_squared_mean": (normed_counts**2).mean(0),
            "local_n_obs": adata.n_obs,
        }
        # Forward the rough dispersions unchanged so that the aggregation node
        # can use them when computing the MoM dispersions.
        local_stats["rough_dispersions"] = rough_dispersions
        return local_stats
local_inverse_size_mean(data_from_opener, shared_state=None, refit_mode=False)

Compute local means of inverse size factors, counts, and squared counts.

Parameters:

Name Type Description Default
data_from_opener AnnData

AnnData returned by the opener. Not used.

required
shared_state dict

Shared state containing rough dispersions from aggregator.

None
refit_mode bool

Whether to run the pipeline in refit mode, after cooks outliers were replaced. If True, the pipeline will be run on refit_adatas instead of local_adatas (default: False).

False

Returns:

Type Description
dict

dictionary containing all quantities required to compute MoM dispersions: local inverse size factor means, counts means, squared counts means, rough dispersions and number of samples.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/substeps.py
@remote_data
@log_remote_data
@reconstruct_adatas
def local_inverse_size_mean(
    self, data_from_opener, shared_state=None, refit_mode: bool = False
) -> dict:
    """Compute local means of inverse size factors, counts and squared counts.

    The rough dispersions received from the aggregator are also stored in
    ``adata.varm["_rough_dispersions"]``.

    Parameters
    ----------
    data_from_opener : ad.AnnData
        AnnData returned by the opener. Not used.

    shared_state : dict
        Shared state from the aggregator; its "rough_dispersions" entry is
        read. Although the default is ``None``, the value is indexed
        unconditionally, so a dictionary must be provided.

    refit_mode : bool
        If True, work on ``refit_adata`` instead of ``local_adata``
        (refit after cooks outliers were replaced). (default: False).

    Returns
    -------
    dict
        Dictionary with the keys "local_inverse_size_mean",
        "local_counts_mean", "local_squared_squared_mean" (mean of the
        squared normed counts), "local_n_obs" and "rough_dispersions".
    """
    adata = self.refit_adata if refit_mode else self.local_adata

    rough_dispersions = shared_state["rough_dispersions"]
    adata.varm["_rough_dispersions"] = rough_dispersions

    inverse_size_factors = 1 / adata.obsm["size_factors"]
    normed_counts = adata.layers["normed_counts"]

    local_stats = {
        "local_inverse_size_mean": inverse_size_factors.mean(),
        "local_counts_mean": normed_counts.mean(0),
        "local_squared_squared_mean": (normed_counts**2).mean(0),
        "local_n_obs": adata.n_obs,
    }
    # Forward the rough dispersions unchanged so that the aggregation node
    # can use them when computing the MoM dispersions.
    local_stats["rough_dispersions"] = rough_dispersions
    return local_stats
LocRoughDispersion

Mixin to compute local rough dispersions.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/substeps.py
class LocRoughDispersion:
    """Mixin computing the local contribution to the rough dispersions."""

    local_adata: AnnData
    refit_adata: AnnData

    @remote_data
    @log_remote_data
    @reconstruct_adatas
    def local_rough_dispersions(
        self, data_from_opener, shared_state, refit_mode: bool = False
    ) -> dict:
        """Compute local rough dispersions, and save the global gram matrix.

        Parameters
        ----------
        data_from_opener : ad.AnnData
            AnnData returned by the opener. Not used.

        shared_state : dict
            Shared state providing "global_feature_vector" and, when
            ``refit_mode`` is False, "global_gram_matrix".

        refit_mode : bool
            If True, work on ``refit_adata`` instead of ``local_adata`` and
            reuse the gram matrix previously cached in
            ``local_adata.uns["_global_gram_matrix"]``. (default: False).

        Returns
        -------
        dict
            Dictionary with the keys "local_rough_dispersions" (unnormalised
            sum over axis 0 of the local data), "local_n_obs" (number of
            samples) and "local_n_params" (number of parameters, i.e. number
            of columns in the design matrix).
        """
        if refit_mode:
            # Reuse the gram matrix cached on local_adata by a non-refit call.
            global_gram_matrix = self.local_adata.uns["_global_gram_matrix"]
        else:
            global_gram_matrix = shared_state["global_gram_matrix"]
            self.local_adata.uns["_global_gram_matrix"] = global_gram_matrix

        global_features = shared_state["global_feature_vector"]
        beta_rough_dispersions = np.linalg.solve(global_gram_matrix, global_features)

        adata = self.refit_adata if refit_mode else self.local_adata

        # Store the unclipped solution on adata before calling set_y_hat; the
        # "_y_hat" layer is read right after that call. Per the original note,
        # this beta is also used later in fit_lin_mu.
        adata.varm["_beta_rough_dispersions"] = beta_rough_dispersions.T
        set_y_hat(adata)

        # Clipping happens only now: fitted values are floored at 1.
        y_hat = np.maximum(adata.layers["_y_hat"], 1)
        squared_residuals = (adata.layers["normed_counts"] - y_hat) ** 2
        per_sample_terms = (squared_residuals - y_hat) / (y_hat**2)
        unnormed_alpha_rde = per_sample_terms.sum(0)

        return {
            "local_rough_dispersions": unnormed_alpha_rde,
            "local_n_obs": adata.n_obs,
            "local_n_params": adata.uns["n_params"],
        }
local_rough_dispersions(data_from_opener, shared_state, refit_mode=False)

Compute local rough dispersions, and save the global gram matrix.

Parameters:

Name Type Description Default
data_from_opener AnnData

AnnData returned by the opener. Not used.

required
shared_state dict

Shared state containing the gram matrix (if refit_mode is False) and the global feature vector.

required
refit_mode bool

Whether to run the pipeline in refit mode, after cooks outliers were replaced. If True, the pipeline will be run on refit_adatas instead of local_adatas (default: False).

False

Returns:

Type Description
dict

Dictionary containing local rough dispersions, number of samples and number of parameters (i.e. number of columns in the design matrix).

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_MoM_dispersions/substeps.py
@remote_data
@log_remote_data
@reconstruct_adatas
def local_rough_dispersions(
    self, data_from_opener, shared_state, refit_mode: bool = False
) -> dict:
    """Compute local rough dispersions, and save the global gram matrix.

    Parameters
    ----------
    data_from_opener : ad.AnnData
        AnnData returned by the opener. Not used.

    shared_state : dict
        Shared state providing "global_feature_vector" and, when
        ``refit_mode`` is False, "global_gram_matrix".

    refit_mode : bool
        If True, work on ``refit_adata`` instead of ``local_adata`` and
        reuse the gram matrix previously cached in
        ``local_adata.uns["_global_gram_matrix"]``. (default: False).

    Returns
    -------
    dict
        Dictionary with the keys "local_rough_dispersions" (unnormalised
        sum over axis 0 of the local data), "local_n_obs" (number of
        samples) and "local_n_params" (number of parameters, i.e. number
        of columns in the design matrix).
    """
    if refit_mode:
        # Reuse the gram matrix cached on local_adata by a non-refit call.
        global_gram_matrix = self.local_adata.uns["_global_gram_matrix"]
    else:
        global_gram_matrix = shared_state["global_gram_matrix"]
        self.local_adata.uns["_global_gram_matrix"] = global_gram_matrix

    global_features = shared_state["global_feature_vector"]
    beta_rough_dispersions = np.linalg.solve(global_gram_matrix, global_features)

    adata = self.refit_adata if refit_mode else self.local_adata

    # Store the unclipped solution on adata before calling set_y_hat; the
    # "_y_hat" layer is read right after that call. Per the original note,
    # this beta is also used later in fit_lin_mu.
    adata.varm["_beta_rough_dispersions"] = beta_rough_dispersions.T
    set_y_hat(adata)

    # Clipping happens only now: fitted values are floored at 1.
    y_hat = np.maximum(adata.layers["_y_hat"], 1)
    squared_residuals = (adata.layers["normed_counts"] - y_hat) ** 2
    per_sample_terms = (squared_residuals - y_hat) / (y_hat**2)
    unnormed_alpha_rde = per_sample_terms.sum(0)

    return {
        "local_rough_dispersions": unnormed_alpha_rde,
        "local_n_obs": adata.n_obs,
        "local_n_params": adata.uns["n_params"],
    }

compute_genewise_dispersions

Main module to compute genewise dispersions.

ComputeGenewiseDispersions

Bases: ComputeDispersionsGridSearch, ComputeMoMDispersions, LocLinMu, GetNumReplicates, ComputeLFC, LocSetMuHat

Mixin class to implement the computation of both genewise and MAP dispersions.

The switch between genewise and MAP dispersions is done by setting the fit_mode argument in the fit_dispersions to either "MLE" or "MAP".

Methods:

Name Description
fit_genewise_dispersions

A method to fit gene-wise dispersions using a grid search. Performs four steps: 1. Compute the first dispersions estimates using a method of moments (MoM) approach. 2. Compute the number of replicates for each combination of factors. This step is necessary to compute the mean estimate in one case, and in downstream steps (cooks distance, etc). 3. Compute an estimate of the mean from these dispersions. 4. Fit the dispersions using a grid search.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_genewise_dispersions.py
class ComputeGenewiseDispersions(
    ComputeDispersionsGridSearch,
    ComputeMoMDispersions,
    LocLinMu,
    GetNumReplicates,
    ComputeLFC,
    LocSetMuHat,
):
    """
    Mixin class to implement the computation of gene-wise dispersions.

    The dispersions are fitted through `fit_dispersions`, whose `fit_mode`
    argument switches between genewise ("MLE") and MAP ("MAP") dispersions.
    This mixin calls it with `fit_mode="MLE"`.

    Methods
    -------
    fit_genewise_dispersions
        A method to fit gene-wise dispersions using a grid search.
        It chains, in order: the method of moments (MoM) dispersions, the
        linear and IRLS estimates of the mean, the number of replicates for
        each combination of factors, the choice of mu_hat, and finally the
        dispersion fit itself.

    """

    def fit_genewise_dispersions(
        self,
        train_data_nodes,
        aggregation_node,
        local_states,
        gram_features_shared_states,
        round_idx,
        clean_models,
        refit_mode: bool = False,
    ):
        """Fit the gene-wise dispersions.

        The steps are run in the following order:
        1. Compute the first dispersions estimates using a
        method of moments (MoM) approach.
        2. Compute the linear estimates of the mean (`fit_lin_mu`).
        3. Compute the IRLS estimates of the mean (`compute_lfc` with
        `lfc_mode="mu_init"`).
        4. Compute the number of replicates for each combination of factors.
        5. Pick between the linear and the IRLS estimates (`set_mu_hat`).
        6. Fit the dispersions (`fit_dispersions` with `fit_mode="MLE"`).

        Parameters
        ----------
        train_data_nodes: list
            List of TrainDataNode.

        aggregation_node: AggregationNode
            The aggregation node.

        local_states: dict
            Local states. Required to propagate intermediate results.

        gram_features_shared_states: list
            The list of shared states output by the compute_size_factors step.
            They contain "local_gram_matrix" and "local_features" fields.

        round_idx: int
            The current round.

        clean_models: bool
            Whether to clean the models after the computation.

        refit_mode: bool
            Whether to run the pipeline in refit mode, after cooks outliers were
            replaced. If True, the pipeline will be run on `refit_adata`s instead
            of `local_adata`s. Note that it is not forwarded to
            `get_num_replicates`. (default: False).

        Returns
        -------
        local_states: dict
            Local states. Required to propagate intermediate results.

        shared_state: dict or list[dict]
            The shared state returned by `fit_dispersions`; documented as a
            dictionary containing:
            - "genewise_dispersions": the MLE dispersions, to be stored locally,
            - "lower_log_bounds": log lower bounds for the grid search (only
            used in internal loop),
            - "upper_log_bounds": log upper bounds for the grid search (only
            used in internal loop).

        round_idx: int
            The updated round index.
        """

        def run_local_step(method, states, shared, description, idx):
            # Every local step of this method shares the same plumbing; only
            # the method, its inputs and the description differ.
            return local_step(
                local_method=method,
                train_data_nodes=train_data_nodes,
                output_local_states=states,
                input_local_states=states,
                input_shared_state=shared,
                aggregation_id=aggregation_node.organization_id,
                description=description,
                round_idx=idx,
                clean_models=clean_models,
                method_params={"refit_mode": refit_mode},
            )

        # ---- MoM dispersions ---- #
        local_states, mom_shared_state, round_idx = self.compute_MoM_dispersions(
            train_data_nodes,
            aggregation_node,
            local_states,
            gram_features_shared_states,
            round_idx,
            clean_models,
            refit_mode=refit_mode,
        )

        # ---- Initial mu estimates ---- #

        # Linear estimates; the shared output of this step is not used.
        local_states, _, round_idx = run_local_step(
            self.fit_lin_mu,
            local_states,
            mom_shared_state,
            "Compute local linear mu estimates.",
            round_idx,
        )

        # IRLS estimates.
        local_states, round_idx = self.compute_lfc(
            train_data_nodes,
            aggregation_node,
            local_states,
            round_idx,
            clean_models=clean_models,
            lfc_mode="mu_init",
            refit_mode=refit_mode,
        )

        # The number of replicates is compared to the number of design matrix
        # columns to decide between the IRLS and the linear estimates.
        local_states, round_idx = self.get_num_replicates(
            train_data_nodes,
            aggregation_node,
            local_states,
            round_idx,
            clean_models=clean_models,
        )

        # Choice of mu_hat; the shared output of this step is not used.
        local_states, _, round_idx = run_local_step(
            self.set_mu_hat,
            local_states,
            None,
            "Pick between linear and irls mu_hat.",
            round_idx,
        )

        # ---- Dispersion fit ---- #
        local_states, shared_state, round_idx = self.fit_dispersions(
            train_data_nodes,
            aggregation_node,
            local_states,
            shared_state=None,
            round_idx=round_idx,
            clean_models=clean_models,
            fit_mode="MLE",
            refit_mode=refit_mode,
        )

        return local_states, shared_state, round_idx
fit_genewise_dispersions(train_data_nodes, aggregation_node, local_states, gram_features_shared_states, round_idx, clean_models, refit_mode=False)

Fit the gene-wise dispersions.

Performs four steps: 1. Compute the first dispersions estimates using a method of moments (MoM) approach. 2. Compute the number of replicates for each combination of factors. This step is necessary to compute the mean estimate in one case, and in downstream steps (cooks distance, etc). 3. Compute an estimate of the mean from these dispersions. 4. Fit the dispersions using a grid search.

Parameters:

Name Type Description Default
train_data_nodes

List of TrainDataNode.

required
aggregation_node

The aggregation node.

required
local_states

Local states. Required to propagate intermediate results.

required
gram_features_shared_states

The list of shared states output by the compute_size_factors step. They contain "local_gram_matrix" and "local_features" fields.

required
round_idx

The current round.

required
clean_models

Whether to clean the models after the computation.

required
refit_mode bool

Whether to run the pipeline in refit mode, after cooks outliers were replaced. If True, the pipeline will be run on refit_adatas instead of local_adatas. (default: False).

False

Returns:

Name Type Description
local_states dict

Local states. Required to propagate intermediate results.

shared_state dict or list[dict]

A dictionary containing: - "genewise_dispersions": the MLE dispersions, to be stored locally, - "lower_log_bounds": log lower bounds for the grid search (only used in internal loop), - "upper_log_bounds": log upper bounds for the grid search (only used in internal loop).

round_idx int

The updated round index.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/compute_genewise_dispersions.py
def fit_genewise_dispersions(
    self,
    train_data_nodes,
    aggregation_node,
    local_states,
    gram_features_shared_states,
    round_idx,
    clean_models,
    refit_mode: bool = False,
):
    """Fit the gene-wise dispersions.

    The steps are run in the following order:
    1. Compute the first dispersions estimates using a
    method of moments (MoM) approach.
    2. Compute the linear estimates of the mean (`fit_lin_mu`).
    3. Compute the IRLS estimates of the mean (`compute_lfc` with
    `lfc_mode="mu_init"`).
    4. Compute the number of replicates for each combination of factors.
    5. Pick between the linear and the IRLS estimates (`set_mu_hat`).
    6. Fit the dispersions (`fit_dispersions` with `fit_mode="MLE"`).

    Parameters
    ----------
    train_data_nodes: list
        List of TrainDataNode.

    aggregation_node: AggregationNode
        The aggregation node.

    local_states: dict
        Local states. Required to propagate intermediate results.

    gram_features_shared_states: list
        The list of shared states output by the compute_size_factors step.
        They contain "local_gram_matrix" and "local_features" fields.

    round_idx: int
        The current round.

    clean_models: bool
        Whether to clean the models after the computation.

    refit_mode: bool
        Whether to run the pipeline in refit mode, after cooks outliers were
        replaced. If True, the pipeline will be run on `refit_adata`s instead
        of `local_adata`s. Note that it is not forwarded to
        `get_num_replicates`. (default: False).

    Returns
    -------
    local_states: dict
        Local states. Required to propagate intermediate results.

    shared_state: dict or list[dict]
        The shared state returned by `fit_dispersions`; documented as a
        dictionary containing:
        - "genewise_dispersions": the MLE dispersions, to be stored locally,
        - "lower_log_bounds": log lower bounds for the grid search (only
        used in internal loop),
        - "upper_log_bounds": log upper bounds for the grid search (only
        used in internal loop).

    round_idx: int
        The updated round index.
    """

    def run_local_step(method, states, shared, description, idx):
        # Every local step of this method shares the same plumbing; only
        # the method, its inputs and the description differ.
        return local_step(
            local_method=method,
            train_data_nodes=train_data_nodes,
            output_local_states=states,
            input_local_states=states,
            input_shared_state=shared,
            aggregation_id=aggregation_node.organization_id,
            description=description,
            round_idx=idx,
            clean_models=clean_models,
            method_params={"refit_mode": refit_mode},
        )

    # ---- MoM dispersions ---- #
    local_states, mom_shared_state, round_idx = self.compute_MoM_dispersions(
        train_data_nodes,
        aggregation_node,
        local_states,
        gram_features_shared_states,
        round_idx,
        clean_models,
        refit_mode=refit_mode,
    )

    # ---- Initial mu estimates ---- #

    # Linear estimates; the shared output of this step is not used.
    local_states, _, round_idx = run_local_step(
        self.fit_lin_mu,
        local_states,
        mom_shared_state,
        "Compute local linear mu estimates.",
        round_idx,
    )

    # IRLS estimates.
    local_states, round_idx = self.compute_lfc(
        train_data_nodes,
        aggregation_node,
        local_states,
        round_idx,
        clean_models=clean_models,
        lfc_mode="mu_init",
        refit_mode=refit_mode,
    )

    # The number of replicates is compared to the number of design matrix
    # columns to decide between the IRLS and the linear estimates.
    local_states, round_idx = self.get_num_replicates(
        train_data_nodes,
        aggregation_node,
        local_states,
        round_idx,
        clean_models=clean_models,
    )

    # Choice of mu_hat; the shared output of this step is not used.
    local_states, _, round_idx = run_local_step(
        self.set_mu_hat,
        local_states,
        None,
        "Pick between linear and irls mu_hat.",
        round_idx,
    )

    # ---- Dispersion fit ---- #
    local_states, shared_state, round_idx = self.fit_dispersions(
        train_data_nodes,
        aggregation_node,
        local_states,
        shared_state=None,
        round_idx=round_idx,
        clean_models=clean_models,
        fit_mode="MLE",
        refit_mode=refit_mode,
    )

    return local_states, shared_state, round_idx

get_num_replicates

Module containing the mixin class to compute the number of replicates.

get_num_replicates

GetNumReplicates

Bases: LocGetDesignMatrixLevels, AggGetCountsLvlForCells, LocFinalizeCellCounts

Mixin class to get the number of replicates for each combination of factors.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/get_num_replicates/get_num_replicates.py
class GetNumReplicates(
    LocGetDesignMatrixLevels, AggGetCountsLvlForCells, LocFinalizeCellCounts
):
    """Mixin class to get the number of replicates for each combination of factors."""

    def get_num_replicates(
        self, train_data_nodes, aggregation_node, local_states, round_idx, clean_models
    ):
        """
        Compute the number of replicates for each combination of factors.

        Runs three steps in sequence: a local step collecting the design
        matrix levels, an aggregation step computing the counts per level,
        and a local step finalizing the cell counts from the aggregated state.

        Parameters
        ----------
        train_data_nodes: list
            List of TrainDataNode.

        aggregation_node: AggregationNode
            The aggregation node.

        local_states: dict
            Local states. Required to propagate intermediate results.

        round_idx: int
            Index of the current round.

        clean_models: bool
            Whether to clean the models after the computation.

        Returns
        -------
        local_states: dict
            Local states, to store the number of replicates and cell level codes.

        round_idx: int
            The updated round index.
        """

        def run_local_step(method, states, shared, description, idx):
            # Both local steps share the same plumbing; only the method, the
            # incoming shared state and the description differ.
            return local_step(
                local_method=method,
                train_data_nodes=train_data_nodes,
                output_local_states=states,
                input_local_states=states,
                input_shared_state=shared,
                aggregation_id=aggregation_node.organization_id,
                description=description,
                round_idx=idx,
                clean_models=clean_models,
            )

        # Local: design matrix levels of each center.
        local_states, design_levels_shared_states, round_idx = run_local_step(
            self.loc_get_design_matrix_levels,
            local_states,
            None,
            "Get local matrix design level",
            round_idx,
        )

        # Aggregation: counts per level across centers.
        counts_lvl_shared_state, round_idx = aggregation_step(
            aggregation_method=self.agg_get_counts_lvl_for_cells,
            train_data_nodes=train_data_nodes,
            aggregation_node=aggregation_node,
            input_shared_states=design_levels_shared_states,
            description="Compute counts level",
            round_idx=round_idx,
            clean_models=clean_models,
        )

        # Local: finalize from the aggregated counts; shared output is discarded.
        local_states, _, round_idx = run_local_step(
            self.loc_finalize_cell_counts,
            local_states,
            counts_lvl_shared_state,
            "Finalize cell counts",
            round_idx,
        )

        return local_states, round_idx
get_num_replicates(train_data_nodes, aggregation_node, local_states, round_idx, clean_models)

Compute the number of replicates for each combination of factors.

Parameters:

Name Type Description Default
train_data_nodes

List of TrainDataNode.

required
aggregation_node

The aggregation node.

required
local_states

Local states. Required to propagate intermediate results.

required
round_idx

Index of the current round.

required
clean_models

Whether to clean the models after the computation.

required

Returns:

Name Type Description
local_states dict

Local states, to store the number of replicates and cell level codes.

round_idx int

The updated round index.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/get_num_replicates/get_num_replicates.py
def get_num_replicates(
    self, train_data_nodes, aggregation_node, local_states, round_idx, clean_models
):
    """
    Compute the number of replicates for each combination of factors.

    Runs three steps in sequence: a local step collecting the design
    matrix levels, an aggregation step computing the counts per level,
    and a local step finalizing the cell counts from the aggregated state.

    Parameters
    ----------
    train_data_nodes: list
        List of TrainDataNode.

    aggregation_node: AggregationNode
        The aggregation node.

    local_states: dict
        Local states. Required to propagate intermediate results.

    round_idx: int
        Index of the current round.

    clean_models: bool
        Whether to clean the models after the computation.

    Returns
    -------
    local_states: dict
        Local states, to store the number of replicates and cell level codes.

    round_idx: int
        The updated round index.
    """

    def run_local_step(method, states, shared, description, idx):
        # Both local steps share the same plumbing; only the method, the
        # incoming shared state and the description differ.
        return local_step(
            local_method=method,
            train_data_nodes=train_data_nodes,
            output_local_states=states,
            input_local_states=states,
            input_shared_state=shared,
            aggregation_id=aggregation_node.organization_id,
            description=description,
            round_idx=idx,
            clean_models=clean_models,
        )

    # Local: design matrix levels of each center.
    local_states, design_levels_shared_states, round_idx = run_local_step(
        self.loc_get_design_matrix_levels,
        local_states,
        None,
        "Get local matrix design level",
        round_idx,
    )

    # Aggregation: counts per level across centers.
    counts_lvl_shared_state, round_idx = aggregation_step(
        aggregation_method=self.agg_get_counts_lvl_for_cells,
        train_data_nodes=train_data_nodes,
        aggregation_node=aggregation_node,
        input_shared_states=design_levels_shared_states,
        description="Compute counts level",
        round_idx=round_idx,
        clean_models=clean_models,
    )

    # Local: finalize from the aggregated counts; shared output is discarded.
    local_states, _, round_idx = run_local_step(
        self.loc_finalize_cell_counts,
        local_states,
        counts_lvl_shared_state,
        "Finalize cell counts",
        round_idx,
    )

    return local_states, round_idx

substeps

AggGetCountsLvlForCells

Mixin that aggregates the counts of the design matrix values.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/get_num_replicates/substeps.py
class AggGetCountsLvlForCells:
    """Mixin summing the design matrix value counts over all shared states."""

    @remote
    @log_remote
    def agg_get_counts_lvl_for_cells(self, shared_states: list[dict]) -> dict:
        """
        Sum the counts of each design matrix value over all shared states.

        Parameters
        ----------
        shared_states : list(dict)
            List of shared states, each holding the key:
            - unique_counts: unique values and counts of the local design matrix

        Returns
        -------
        dict
            Dictionary with a single key:
            - counts_by_lvl: for each design matrix value present in at least
              one shared state, the integer sum of its counts.
        """
        # One column per shared state; rows are aligned on the design matrix
        # values by the column-wise concatenation.
        local_counts = [state["unique_counts"] for state in shared_states]
        stacked_counts = pd.concat(local_counts, axis=1)

        # A value missing from one shared state appears as NaN: count it as 0.
        filled_counts = stacked_counts.fillna(0)
        counts_by_lvl = filled_counts.sum(axis=1).astype(int)

        return {"counts_by_lvl": counts_by_lvl}
agg_get_counts_lvl_for_cells(shared_states)

Aggregate the counts of the design matrix values.

Parameters:

Name Type Description Default
shared_states list(dict)

List of shared states with the following key: - unique_counts: unique values and counts of the local design matrix

required

Returns:

Type Description
dict

Dictionary with keys labeling the different values taken by the overall design matrix. Each value of the dictionary contains the sum of the counts of the corresponding design matrix value and the level.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/get_num_replicates/substeps.py
@remote
@log_remote
def agg_get_counts_lvl_for_cells(self, shared_states: list[dict]) -> dict:
    """
    Sum the counts of each design matrix value over all shared states.

    Parameters
    ----------
    shared_states : list(dict)
        List of shared states, each holding the key:
        - unique_counts: unique values and counts of the local design matrix

    Returns
    -------
    dict
        Dictionary with a single key:
        - counts_by_lvl: for each design matrix value present in at least
          one shared state, the integer sum of its counts.
    """
    # One column per shared state; rows are aligned on the design matrix
    # values by the column-wise concatenation.
    local_counts = [state["unique_counts"] for state in shared_states]
    stacked_counts = pd.concat(local_counts, axis=1)

    # A value missing from one shared state appears as NaN: count it as 0.
    filled_counts = stacked_counts.fillna(0)
    counts_by_lvl = filled_counts.sum(axis=1).astype(int)

    return {"counts_by_lvl": counts_by_lvl}
LocFinalizeCellCounts

Mixin that finalizes the cell counts.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/get_num_replicates/substeps.py
class LocFinalizeCellCounts:
    """Mixin that finalizes the cell counts."""

    local_adata: AnnData

    @remote_data
    @log_remote_data
    @reconstruct_adatas
    def loc_finalize_cell_counts(self, data_from_opener, shared_state=None) -> None:
        """
        Finalize the cell counts.

        Stores ``num_replicates`` in ``local_adata.uns`` and ``cells`` in
        ``local_adata.obs``.

        Parameters
        ----------
        data_from_opener : ad.AnnData
            AnnData returned by the opener. Not used.

        shared_state : dict
            Dictionary with the following key:
            - counts_by_lvl: counts of each value taken by the overall design
              matrix, indexed by these values.
            Although it has a default for signature compatibility, it must be
            provided.

        Raises
        ------
        TypeError
            If ``shared_state`` is not provided.

        """
        # The default used to be the `dict` type itself (a typo for a type
        # annotation), which let a call without shared state corrupt
        # `uns["num_replicates"]` before failing. Fail early and clearly instead.
        if shared_state is None:
            raise TypeError(
                "loc_finalize_cell_counts requires a shared_state containing "
                "the key 'counts_by_lvl'."
            )

        counts_by_lvl = shared_state["counts_by_lvl"]

        # In order to keep the same objects 'num_replicates' and 'cells' used in
        # PyDESeq2, we provide names (0, 1, 2...) to the possible values of the
        # design matrix, called "lvl".
        # The index of 'num_replicates' is the lvl names (0,1,2...) and its values
        # the counts of these lvl
        # 'cells' index is the index of the cells in the adata and its values the lvl
        # name (0,1,2..) of the cell.
        self.local_adata.uns["num_replicates"] = pd.Series(counts_by_lvl.values)
        # For each row of the local design matrix, take the position of the
        # first matching entry in the index of counts_by_lvl.
        self.local_adata.obs["cells"] = [
            np.argwhere(counts_by_lvl.index == tuple(design))[0, 0]
            for design in self.local_adata.obsm["design_matrix"].values
        ]
loc_finalize_cell_counts(data_from_opener, shared_state=dict)

Finalize the cell counts.

Parameters:

Name Type Description Default
data_from_opener AnnData

AnnData returned by the opener. Not used.

required
shared_state dict

Dictionary with keys labeling the different values taken by the overall design matrix. Each value of the dictionary contains the sum of the counts of the corresponding design matrix value and the level.

dict
Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/get_num_replicates/substeps.py
@remote_data
@log_remote_data
@reconstruct_adatas
def loc_finalize_cell_counts(self, data_from_opener, shared_state=None) -> None:
    """
    Finalize the cell counts.

    Stores ``num_replicates`` in ``local_adata.uns`` and ``cells`` in
    ``local_adata.obs``.

    Parameters
    ----------
    data_from_opener : ad.AnnData
        AnnData returned by the opener. Not used.

    shared_state : dict
        Dictionary with the following key:
        - counts_by_lvl: counts of each value taken by the overall design
          matrix, indexed by these values.
        Although it has a default for signature compatibility, it must be
        provided.

    Raises
    ------
    TypeError
        If ``shared_state`` is not provided.

    """
    # The default used to be the `dict` type itself (a typo for a type
    # annotation), which let a call without shared state corrupt
    # `uns["num_replicates"]` before failing. Fail early and clearly instead.
    if shared_state is None:
        raise TypeError(
            "loc_finalize_cell_counts requires a shared_state containing "
            "the key 'counts_by_lvl'."
        )

    counts_by_lvl = shared_state["counts_by_lvl"]

    # In order to keep the same objects 'num_replicates' and 'cells' used in
    # PyDESeq2, we provide names (0, 1, 2...) to the possible values of the
    # design matrix, called "lvl".
    # The index of 'num_replicates' is the lvl names (0,1,2...) and its values
    # the counts of these lvl
    # 'cells' index is the index of the cells in the adata and its values the lvl
    # name (0,1,2..) of the cell.
    self.local_adata.uns["num_replicates"] = pd.Series(counts_by_lvl.values)
    # For each row of the local design matrix, take the position of the
    # first matching entry in the index of counts_by_lvl.
    self.local_adata.obs["cells"] = [
        np.argwhere(counts_by_lvl.index == tuple(design))[0, 0]
        for design in self.local_adata.obsm["design_matrix"].values
    ]
LocGetDesignMatrixLevels

Mixin to get the unique values of the local design matrix.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/get_num_replicates/substeps.py
class LocGetDesignMatrixLevels:
    """Mixin to get the unique values of the local design matrix."""

    local_adata: AnnData

    @remote_data
    @log_remote_data
    @reconstruct_adatas
    def loc_get_design_matrix_levels(self, data_from_opener, shared_state=None) -> dict:
        """
        Get the values of the local design matrix.

        Parameters
        ----------
        data_from_opener : ad.AnnData
            AnnData returned by the opener. Not used.
        shared_state : dict or None
            Not used. Defaults to ``None`` (the default used to be the ``dict``
            type itself, a typo for a type annotation).

        Returns
        -------
        dict
            Dictionary with the following key:
            - unique_counts: unique values and counts of the local design matrix

        """
        # value_counts on the design matrix yields one entry per distinct row,
        # with the number of times this row occurs locally.
        unique_counts = self.local_adata.obsm["design_matrix"].value_counts()

        return {"unique_counts": unique_counts}
loc_get_design_matrix_levels(data_from_opener, shared_state=dict)

Get the values of the local design matrix.

Parameters:

Name Type Description Default
data_from_opener AnnData

AnnData returned by the opener. Not used.

required
shared_state dict

Not used.

dict

Returns:

Type Description
dict

Dictionary with the following key: - unique_counts: unique values and counts of the local design matrix

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/get_num_replicates/substeps.py
@remote_data
@log_remote_data
@reconstruct_adatas
def loc_get_design_matrix_levels(self, data_from_opener, shared_state=None) -> dict:
    """
    Get the values of the local design matrix.

    Parameters
    ----------
    data_from_opener : ad.AnnData
        AnnData returned by the opener. Not used.
    shared_state : dict or None
        Not used. Defaults to ``None`` (the default used to be the ``dict``
        type itself, a typo for a type annotation).

    Returns
    -------
    dict
        Dictionary with the following key:
        - unique_counts: unique values and counts of the local design matrix

    """
    # value_counts on the design matrix yields one entry per distinct row,
    # with the number of times this row occurs locally.
    unique_counts = self.local_adata.obsm["design_matrix"].value_counts()

    return {"unique_counts": unique_counts}

substeps

LocLinMu

Mixin to fit linear mu estimates locally.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/substeps.py
class LocLinMu:
    """Mixin storing shared quantities locally and fitting linear mu estimates."""

    local_adata: AnnData
    refit_adata: AnnData
    min_mu: float
    max_disp: float

    @remote_data
    @log_remote_data
    @reconstruct_adatas
    def fit_lin_mu(
        self, data_from_opener, shared_state, min_mu=0.5, refit_mode: bool = False
    ):
        """Save shared quantities in the adata and fit linear mu estimates.

        Parameters
        ----------
        data_from_opener : ad.AnnData
            Not used.

        shared_state : dict
            Shared state from which the following keys are read:
            - "MoM_dispersions": MoM dispersions,
            - "non_zero": mask saved in ``varm["non_zero"]``,
            - "tot_num_samples": Total number of samples (only read when
              ``refit_mode`` is False),
            - "tot_counts_mean": saved in ``varm["_normed_means"]``.

        min_mu : float
            Lower threshold for fitted means, for numerical stability.
            (default: ``0.5``).

        refit_mode : bool
            Whether to run the pipeline in refit mode. If True, the pipeline will be run
            on `refit_adata`s instead of `local_adata`s. (default: ``False``).

        """
        adata = self.refit_adata if refit_mode else self.local_adata

        # MoM dispersions come from the previous step
        adata.varm["_MoM_dispersions"] = shared_state["MoM_dispersions"]

        # mask of all zero genes.
        # TODO: check that we should also do this in refit mode
        adata.varm["non_zero"] = shared_state["non_zero"]

        # The sample count and max_disp are left untouched in refit mode
        if not refit_mode:
            tot_num_samples = shared_state["tot_num_samples"]
            self.local_adata.uns["tot_num_samples"] = tot_num_samples

            # max_disp is the larger of the configured value and the sample count
            self.local_adata.uns["max_disp"] = max(self.max_disp, tot_num_samples)

        # base_mean, kept for independent filtering
        adata.varm["_normed_means"] = shared_state["tot_counts_mean"]

        # finally compute mu_hat
        set_fit_lin_mu_hat(adata, min_mu=min_mu)
fit_lin_mu(data_from_opener, shared_state, min_mu=0.5, refit_mode=False)

Fit linear mu estimates and store them locally.

Parameters:

Name Type Description Default
data_from_opener AnnData

Not used.

required
shared_state dict

Contains values to be saved in local adata: - "MoM_dispersions": MoM dispersions, - "non_zero": Mask of all zero genes, - "tot_num_samples": Total number of samples, - "tot_counts_mean": Values saved as the normed means (base mean used for independent filtering).

required
min_mu float

Lower threshold for fitted means, for numerical stability. (default: 0.5).

0.5
refit_mode bool

Whether to run the pipeline in refit mode. If True, the pipeline will be run on refit_adatas instead of local_adatas. (default: False).

False
Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/substeps.py
@remote_data
@log_remote_data
@reconstruct_adatas
def fit_lin_mu(
    self, data_from_opener, shared_state, min_mu=0.5, refit_mode: bool = False
):
    """Save shared quantities in the adata and fit linear mu estimates.

    Parameters
    ----------
    data_from_opener : ad.AnnData
        Not used.

    shared_state : dict
        Shared state from which the following keys are read:
        - "MoM_dispersions": MoM dispersions,
        - "non_zero": mask saved in ``varm["non_zero"]``,
        - "tot_num_samples": Total number of samples (only read when
          ``refit_mode`` is False),
        - "tot_counts_mean": saved in ``varm["_normed_means"]``.

    min_mu : float
        Lower threshold for fitted means, for numerical stability.
        (default: ``0.5``).

    refit_mode : bool
        Whether to run the pipeline in refit mode. If True, the pipeline will be run
        on `refit_adata`s instead of `local_adata`s. (default: ``False``).

    """
    adata = self.refit_adata if refit_mode else self.local_adata

    # MoM dispersions come from the previous step
    adata.varm["_MoM_dispersions"] = shared_state["MoM_dispersions"]

    # mask of all zero genes.
    # TODO: check that we should also do this in refit mode
    adata.varm["non_zero"] = shared_state["non_zero"]

    # The sample count and max_disp are left untouched in refit mode
    if not refit_mode:
        tot_num_samples = shared_state["tot_num_samples"]
        self.local_adata.uns["tot_num_samples"] = tot_num_samples

        # max_disp is the larger of the configured value and the sample count
        self.local_adata.uns["max_disp"] = max(self.max_disp, tot_num_samples)

    # base_mean, kept for independent filtering
    adata.varm["_normed_means"] = shared_state["tot_counts_mean"]

    # finally compute mu_hat
    set_fit_lin_mu_hat(adata, min_mu=min_mu)

LocSetMuHat

Mixin to set mu estimates locally.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/substeps.py
class LocSetMuHat:
    """Mixin setting the final mu estimates in the local adata."""

    local_adata: AnnData
    refit_adata: AnnData

    @remote_data
    @log_remote_data
    @reconstruct_adatas
    def set_mu_hat(
        self,
        data_from_opener,
        shared_state,
        refit_mode: bool = False,
    ) -> None:
        """Pick between linear and IRLS mu estimates.

        The choice is delegated to ``set_mu_hat_layer``; the two candidate
        layers are removed from the adata afterwards.

        Parameters
        ----------
        data_from_opener : ad.AnnData
            Not used.

        shared_state : dict
            Not used.

        refit_mode : bool
            Whether to run on `refit_adata`s instead of `local_adata`s.
            (default: ``False``).
        """
        adata = self.refit_adata if refit_mode else self.local_adata

        # TODO make sure that the adata has the num_replicates and the n_params
        set_mu_hat_layer(adata)

        # Both candidate layers are dropped once the choice has been made.
        for candidate_layer in ("_fit_lin_mu_hat", "_irls_mu_hat"):
            del adata.layers[candidate_layer]
set_mu_hat(data_from_opener, shared_state, refit_mode=False)

Pick between linear and IRLS mu estimates.

Parameters:

Name Type Description Default
data_from_opener AnnData

Not used.

required
shared_state dict

Not used.

required
refit_mode bool

Whether to run on refit_adatas instead of local_adatas. (default: False).

False
Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_genewise_dispersions/substeps.py
@remote_data
@log_remote_data
@reconstruct_adatas
def set_mu_hat(
    self,
    data_from_opener,
    shared_state,
    refit_mode: bool = False,
) -> None:
    """Pick between linear and IRLS mu estimates.

    The choice is delegated to ``set_mu_hat_layer``; the two candidate
    layers are removed from the adata afterwards.

    Parameters
    ----------
    data_from_opener : ad.AnnData
        Not used.

    shared_state : dict
        Not used.

    refit_mode : bool
        Whether to run on `refit_adata`s instead of `local_adata`s.
        (default: ``False``).
    """
    adata = self.refit_adata if refit_mode else self.local_adata

    # TODO make sure that the adata has the num_replicates and the n_params
    set_mu_hat_layer(adata)

    # Both candidate layers are dropped once the choice has been made.
    for candidate_layer in ("_fit_lin_mu_hat", "_irls_mu_hat"):
        del adata.layers[candidate_layer]

compute_lfc

Module which contains the Mixin in charge of fitting log fold changes.

compute_lfc

Module containing the ComputeLFC method.

ComputeLFC

Bases: LocGetGramMatrixAndLogFeatures, AggCreateBetaInit, LocSaveLFC, FedProxQuasiNewton, FedIRLS

Mixin class to implement the LFC computation algorithm.

The goal of this class is to implement the IRLS algorithm specifically applied to the negative binomial distribution, with fixed dispersion parameter, and in the case where it fails, to catch it with the FedProxQuasiNewton algorithm.

This class also initializes the beta parameters and computes the final hat matrix.

Methods:

Name Description
compute_lfc

The main method to compute the log fold changes by running the IRLS algorithm and catching it with the FedProxQuasiNewton algorithm.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_lfc/compute_lfc.py
class ComputeLFC(
    LocGetGramMatrixAndLogFeatures,
    AggCreateBetaInit,
    LocSaveLFC,
    FedProxQuasiNewton,
    FedIRLS,
):
    r"""Mixin class orchestrating the computation of log fold changes.

    The IRLS algorithm is run for the negative binomial distribution with a
    fixed dispersion parameter; the FedProxQuasiNewton algorithm is then used
    to catch the cases where IRLS fails.

    The class is also in charge of initializing the beta parameters and of
    computing the final hat matrix.

    Methods
    -------
    compute_lfc
        Entry point: initializes beta, runs IRLS, catches its failures with
        FedProxQuasiNewton and saves the result locally.

    """

    def compute_lfc(
        self,
        train_data_nodes: list,
        aggregation_node: AggregationNode,
        local_states: dict,
        round_idx: int,
        clean_models: bool = True,
        lfc_mode: Literal["lfc", "mu_init"] = "lfc",
        refit_mode: bool = False,
    ):
        """Compute the log fold changes.

        The steps are chained in this order: a local step feeding the beta
        initialization, its aggregation, federated IRLS, FedProxQuasiNewton
        started in "irls_catch" mode, and a last local step saving the result.

        Parameters
        ----------
        train_data_nodes: list
            List of TrainDataNode.

        aggregation_node: AggregationNode
            The aggregation node.

        local_states: dict
            Local states. Required to propagate intermediate results.

        round_idx: int
            The current round.

        clean_models: bool
            If True, the models are cleaned.

        lfc_mode: Literal["lfc", "mu_init"]
            The mode of the IRLS algorithm ("lfc" or "mu_init").

        refit_mode: bool
            Whether to run the pipeline in refit mode, after cooks outliers were
            replaced. If True, the pipeline will be run on `refit_adata`s instead
            of `local_adata`s. (default: False).

        Returns
        -------
        local_states: dict
            Local states. Required to propagate intermediate results.

        round_idx: int
            The updated round index.

        """
        aggregation_id = aggregation_node.organization_id

        # ---- Initialization: local quantities for the initial beta ---- #
        local_states, local_init_states, round_idx = local_step(
            local_method=self.get_gram_matrix_and_log_features,
            method_params={"lfc_mode": lfc_mode, "refit_mode": refit_mode},
            train_data_nodes=train_data_nodes,
            input_local_states=local_states,
            output_local_states=local_states,
            input_shared_state=None,
            aggregation_id=aggregation_id,
            description="Create local initialization beta.",
            round_idx=round_idx,
            clean_models=clean_models,
        )

        # ---- Initialization: global initial beta ---- #
        irls_init_state, round_idx = aggregation_step(
            aggregation_method=self.create_beta_init,
            train_data_nodes=train_data_nodes,
            aggregation_node=aggregation_node,
            input_shared_states=local_init_states,
            description="Create initialization beta paramater.",
            round_idx=round_idx,
            clean_models=clean_models,
        )

        # ---- Federated IRLS ---- #
        local_states, irls_final_state, round_idx = self.run_fed_irls(
            train_data_nodes=train_data_nodes,
            aggregation_node=aggregation_node,
            local_states=local_states,
            input_shared_state=irls_init_state,
            round_idx=round_idx,
            clean_models=clean_models,
            refit_mode=refit_mode,
        )

        # ---- FedProxQuasiNewton, started from the IRLS output ---- #
        local_states, pqn_final_state, round_idx = self.run_fed_PQN(
            train_data_nodes=train_data_nodes,
            aggregation_node=aggregation_node,
            local_states=local_states,
            PQN_shared_state=irls_final_state,
            first_iteration_mode="irls_catch",
            round_idx=round_idx,
            clean_models=clean_models,
            refit_mode=refit_mode,
        )

        # ---- Save the result locally; the shared output is discarded ---- #
        local_states, _, round_idx = local_step(
            local_method=self.save_lfc_to_local,
            method_params={"refit_mode": refit_mode},
            train_data_nodes=train_data_nodes,
            input_local_states=local_states,
            output_local_states=local_states,
            input_shared_state=pqn_final_state,
            aggregation_id=aggregation_id,
            description="Compute local hat matrix summands and last nll.",
            round_idx=round_idx,
            clean_models=clean_models,
        )

        return local_states, round_idx
compute_lfc(train_data_nodes, aggregation_node, local_states, round_idx, clean_models=True, lfc_mode='lfc', refit_mode=False)

Compute the log fold changes.

Parameters:

Name Type Description Default
train_data_nodes list

List of TrainDataNode.

required
aggregation_node AggregationNode

The aggregation node.

required
local_states dict

Local states. Required to propagate intermediate results.

required
round_idx int

The current round.

required
clean_models bool

If True, the models are cleaned.

True
lfc_mode Literal['lfc', 'mu_init']

The mode of the IRLS algorithm ("lfc" or "mu_init").

'lfc'
refit_mode bool

Whether to run the pipeline in refit mode, after cooks outliers were replaced. If True, the pipeline will be run on refit_adatas instead of local_adatas. (default: False).

False

Returns:

Name Type Description
local_states dict

Local states. Required to propagate intermediate results.

round_idx int

The updated round index.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_lfc/compute_lfc.py
def compute_lfc(
    self,
    train_data_nodes: list,
    aggregation_node: AggregationNode,
    local_states: dict,
    round_idx: int,
    clean_models: bool = True,
    lfc_mode: Literal["lfc", "mu_init"] = "lfc",
    refit_mode: bool = False,
):
    """Compute the log fold changes.

    The steps are chained in this order: a local step feeding the beta
    initialization, its aggregation, federated IRLS, FedProxQuasiNewton
    started in "irls_catch" mode, and a last local step saving the result.

    Parameters
    ----------
    train_data_nodes: list
        List of TrainDataNode.

    aggregation_node: AggregationNode
        The aggregation node.

    local_states: dict
        Local states. Required to propagate intermediate results.

    round_idx: int
        The current round.

    clean_models: bool
        If True, the models are cleaned.

    lfc_mode: Literal["lfc", "mu_init"]
        The mode of the IRLS algorithm ("lfc" or "mu_init").

    refit_mode: bool
        Whether to run the pipeline in refit mode, after cooks outliers were
        replaced. If True, the pipeline will be run on `refit_adata`s instead
        of `local_adata`s. (default: False).

    Returns
    -------
    local_states: dict
        Local states. Required to propagate intermediate results.

    round_idx: int
        The updated round index.

    """
    aggregation_id = aggregation_node.organization_id

    # ---- Initialization: local quantities for the initial beta ---- #
    local_states, local_init_states, round_idx = local_step(
        local_method=self.get_gram_matrix_and_log_features,
        method_params={"lfc_mode": lfc_mode, "refit_mode": refit_mode},
        train_data_nodes=train_data_nodes,
        input_local_states=local_states,
        output_local_states=local_states,
        input_shared_state=None,
        aggregation_id=aggregation_id,
        description="Create local initialization beta.",
        round_idx=round_idx,
        clean_models=clean_models,
    )

    # ---- Initialization: global initial beta ---- #
    irls_init_state, round_idx = aggregation_step(
        aggregation_method=self.create_beta_init,
        train_data_nodes=train_data_nodes,
        aggregation_node=aggregation_node,
        input_shared_states=local_init_states,
        description="Create initialization beta paramater.",
        round_idx=round_idx,
        clean_models=clean_models,
    )

    # ---- Federated IRLS ---- #
    local_states, irls_final_state, round_idx = self.run_fed_irls(
        train_data_nodes=train_data_nodes,
        aggregation_node=aggregation_node,
        local_states=local_states,
        input_shared_state=irls_init_state,
        round_idx=round_idx,
        clean_models=clean_models,
        refit_mode=refit_mode,
    )

    # ---- FedProxQuasiNewton, started from the IRLS output ---- #
    local_states, pqn_final_state, round_idx = self.run_fed_PQN(
        train_data_nodes=train_data_nodes,
        aggregation_node=aggregation_node,
        local_states=local_states,
        PQN_shared_state=irls_final_state,
        first_iteration_mode="irls_catch",
        round_idx=round_idx,
        clean_models=clean_models,
        refit_mode=refit_mode,
    )

    # ---- Save the result locally; the shared output is discarded ---- #
    local_states, _, round_idx = local_step(
        local_method=self.save_lfc_to_local,
        method_params={"refit_mode": refit_mode},
        train_data_nodes=train_data_nodes,
        input_local_states=local_states,
        output_local_states=local_states,
        input_shared_state=pqn_final_state,
        aggregation_id=aggregation_id,
        description="Compute local hat matrix summands and last nll.",
        round_idx=round_idx,
        clean_models=clean_models,
    )

    return local_states, round_idx

substeps

Module to implement the substeps for the fitting of log fold changes.

This module contains all these substeps as mixin classes.

AggCreateBetaInit

Mixin to create the beta init.

Methods:

Name Description
create_beta_init

A remote method. Creates the beta init (initialization value for the ComputeLFC algorithm) and returns the initialization state for the IRLS algorithm containing this initialization value and other necessary quantities.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_lfc/substeps.py
class AggCreateBetaInit:
    """Mixin building the initial state of the IRLS algorithm.

    Methods
    -------
    create_beta_init
        A remote method. Computes the beta init (initialization value for the
        ComputeLFC algorithm) and returns it together with the other
        quantities needed to start the IRLS algorithm.
    """

    @remote
    @log_remote
    def create_beta_init(self, shared_states: list[dict]) -> dict[str, Any]:
        """Create the beta init.

        If the gram matrix is full rank, beta is obtained by solving the least
        squares regression system. Otherwise, it is the average of the normed
        log means, weighted by the number of observations.

        Parameters
        ----------
        shared_states: list[dict]
            A list of dictionaries containing the following
            keys:
            - gram_full_rank: bool
                Whether the gram matrix is full rank. Read from the first
                state only.
            - n_non_zero_genes: int
                The number of non zero genes. Read from the first state only.
            - n_params: int
                The number of parameters. Not read by this method.
            If the gram matrix is full rank, the state contains:
            -  local_log_features: ndarray
                The local log features, summed over all states.
            - global_gram_matrix: ndarray
                The global gram matrix. Read from the first state only.
            If the gram matrix is not full rank, the state contains:
            - normed_log_means: ndarray
                The normed log means.
            - n_obs: int
                The number of observations, used as averaging weight.


        Returns
        -------
        dict[str, Any]
            A dictionary containing all the necessary info to run IRLS.
            It contains the following fields:
            - beta: ndarray
                The initial beta.
            - irls_diverged_mask: ndarray
                A boolean mask of shape (n_non_zero_genes,), all False.
            - irls_mask: ndarray
                A boolean mask of shape (n_non_zero_genes,), all True.
            - global_nll: ndarray
                An array of shape (n_non_zero_genes,) filled with 1000.0.
            - round_number_irls: int
                The round number of the IRLS algorithm, set to 0.


        """
        # Quantities that are global are taken from the first state
        first_state = shared_states[0]
        is_full_rank = first_state["gram_full_rank"]
        n_non_zero_genes = first_state["n_non_zero_genes"]

        # Step 1: beta init, depending on the rank of the gram matrix
        if is_full_rank:
            gram_matrix = first_state["global_gram_matrix"]

            # Sum the local feature vectors, then solve the linear system
            summed_features = sum(
                state["local_log_features"] for state in shared_states
            )
            beta_init = np.linalg.solve(gram_matrix, summed_features.T).T

        else:
            # Average of the log means, weighted by the number of observations
            total_obs = sum(state["n_obs"] for state in shared_states)
            weighted_log_means = sum(
                state["normed_log_means"] * state["n_obs"] for state in shared_states
            )
            beta_init = weighted_log_means / total_obs

        # Step 2: remaining fields of the initial IRLS state
        return {
            "beta": beta_init,
            "irls_diverged_mask": np.zeros(n_non_zero_genes, dtype=bool),
            "irls_mask": np.ones(n_non_zero_genes, dtype=bool),
            "global_nll": np.full(n_non_zero_genes, 1000.0),
            "round_number_irls": 0,
        }
create_beta_init(shared_states)

Create the beta init.

It does so either by solving the least squares regression system if the gram matrix is full rank, or by aggregating the log means if the gram matrix is not full rank.

Parameters:

Name Type Description Default
shared_states list[dict]

A list of dictionaries containing the following keys: - gram_full_rank: bool Whether the gram matrix is full rank. - n_non_zero_genes: int The number of non zero genes. - n_params: int The number of parameters. If the gram matrix is full rank, the state contains: - local_log_features: ndarray The local log features, only if the gram matrix is full rank. - global_gram_matrix: ndarray The global gram matrix, only if the gram matrix is full rank. If the gram matrix is not full rank, the state contains: - normed_log_means: ndarray The normed log means, only if the gram matrix is not full rank. - n_obs: int The number of observations, only if the gram matrix is not full rank.

required

Returns:

Type Description
dict[str, Any]

A dictionary containing all the necessary info to run IRLS. It contains the following fields: - beta: ndarray The initial beta, of shape (n_non_zero_genes, n_params). - irls_diverged_mask: ndarray A boolean mask indicating if fed avg should be used for a given gene (shape: (n_non_zero_genes,)). Is set to False initially, and will be set to True if the gene has diverged. - irls_mask: ndarray A boolean mask indicating if IRLS should be used for a given gene (shape: (n_non_zero_genes,)). Is set to True initially, and will be set to False if the gene has converged or diverged. - global_nll: ndarray The global_nll of the current beta from the previous beta, of shape (n_non_zero_genes,). - round_number_irls: int The current round number of the IRLS algorithm.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_lfc/substeps.py
@remote
@log_remote
def create_beta_init(self, shared_states: list[dict]) -> dict[str, Any]:
    """Build the initial state of the IRLS algorithm from the local states.

    The initial beta is obtained in one of two ways, depending on the
    ``gram_full_rank`` flag read from the first shared state:

    - if the flag is truthy, the local log features of all centers are summed
      and the linear system defined by the global gram matrix is solved;
    - otherwise, the local normed log means are averaged, each center being
      weighted by its number of observations.

    Parameters
    ----------
    shared_states : list[dict]
        One dictionary per center. The following keys are read:
        - gram_full_rank: bool
            Whether the gram matrix is full rank (read from the first state only).
        - n_non_zero_genes: int
            The number of non zero genes (read from the first state only).
        When ``gram_full_rank`` is truthy:
        - local_log_features: ndarray
            The local log features (read from every state).
        - global_gram_matrix: ndarray
            The global gram matrix (read from the first state only).
        Otherwise:
        - normed_log_means: ndarray
            The normed log means (read from every state).
        - n_obs: int
            The number of observations (read from every state).

    Returns
    -------
    dict[str, Any]
        The information needed to start IRLS:
        - beta: ndarray
            The initial beta. Documented upstream as having shape
            (n_non_zero_genes, n_params).
        - irls_diverged_mask: ndarray
            Boolean array of shape (n_non_zero_genes,), initialised to False.
        - irls_mask: ndarray
            Boolean array of shape (n_non_zero_genes,), initialised to True.
        - global_nll: ndarray
            Float array of shape (n_non_zero_genes,), initialised to 1000.0.
        - round_number_irls: int
            The IRLS round counter, initialised to 0.

    """
    reference_state = shared_states[0]
    is_full_rank = reference_state["gram_full_rank"]
    n_genes = reference_state["n_non_zero_genes"]

    if is_full_rank:
        # Least squares initialisation: solve G @ beta.T = (sum of features).T
        gram = reference_state["global_gram_matrix"]
        summed_features = sum(state["local_log_features"] for state in shared_states)
        beta_init = np.linalg.solve(gram, summed_features.T).T
    else:
        # Fallback: observation-weighted average of the local log means
        total_obs = sum(state["n_obs"] for state in shared_states)
        weighted_log_means = sum(
            state["normed_log_means"] * state["n_obs"] for state in shared_states
        )
        beta_init = weighted_log_means / total_obs

    # Bookkeeping arrays, one entry per non zero gene
    return {
        "beta": beta_init,
        "irls_diverged_mask": np.zeros(n_genes, dtype=bool),
        "irls_mask": np.ones(n_genes, dtype=bool),
        "global_nll": np.full(n_genes, 1000.0),
        "round_number_irls": 0,
    }

LocGetGramMatrixAndLogFeatures

Mixin accessing the quantities to compute the initial beta of ComputeLFC.

Attributes:

Name Type Description
local_adata AnnData

The local AnnData object.

Methods:

Name Description
get_gram_matrix_and_log_features

A remote_data method. Creates the local quantities necessary to compute the initial beta. If the gram matrix is full rank, it shares the features vector and the gram matrix. If the gram matrix is not full rank, it shares the normed log means and the number of observations.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_lfc/substeps.py
class LocGetGramMatrixAndLogFeatures:
    """Mixin accessing the quantities to compute the initial beta of ComputeLFC.

    Attributes
    ----------
    local_adata : AnnData
        The local AnnData object.
    refit_adata : AnnData
        The AnnData object used instead of ``local_adata`` when ``refit_mode``
        is True.

    Methods
    -------
    get_gram_matrix_and_log_features
        A remote_data method. Creates the local quantities necessary
        to compute the initial beta.
        If the gram matrix is full rank, it shares the features vector
        and the gram matrix. If the gram matrix is not full rank, it shares
        the normed log means and the number of observations.

    """

    local_adata: AnnData
    refit_adata: AnnData

    @remote_data
    @log_remote_data
    @reconstruct_adatas
    def get_gram_matrix_and_log_features(
        self,
        data_from_opener: AnnData,
        shared_state: dict[str, Any],
        lfc_mode: Literal["lfc", "mu_init"],
        refit_mode: bool = False,
    ):
        """Create the local quantities necessary to compute the initial beta.

        The following fields are read:
        - ``self.local_adata.uns["_global_gram_matrix"]``: the global gram
          matrix. It is always taken from ``local_adata``, even in refit mode.
        - ``adata.uns["n_params"]``: the number of parameters.
        - ``adata.varm["non_zero"]``: mask used to select the non zero genes.

        Depending on ``lfc_mode``, the following fields of ``adata.uns`` are set:
        - _irls_mu_param_name: str
            ``"_mu_LFC"`` in "lfc" mode, ``"_irls_mu_hat"`` in "mu_init" mode.
        - _irls_beta_param_name: str
            ``"LFC"`` in "lfc" mode, ``"_mu_hat_LFC"`` in "mu_init" mode.
        - _irls_disp_param_name: str
            ``"dispersions"`` in "lfc" mode, ``"_MoM_dispersions"`` in
            "mu_init" mode.
        - _lfc_mode: str
            The value of ``lfc_mode``.

        Parameters
        ----------
        data_from_opener : AnnData
            Not used.

        shared_state : dict
            Not used, all the necessary info is stored in the local adata.

        lfc_mode : Literal["lfc", "mu_init"]
            The mode of the IRLS algorithm ("lfc", or "mu_init").

        refit_mode : bool
            Whether to run the pipeline on `refit_adata` instead of `local_adata`.

        Returns
        -------
        dict
            The state to share to the server.
            It always contains the following fields:
            - gram_full_rank: bool
                Whether the gram matrix is full rank.
            - n_non_zero_genes: int
                The number of non zero genes.
            - n_params: int
                The number of parameters.
            - If the gram matrix is full rank, the state contains:
                - local_log_features: ndarray
                    The local log features.
                - global_gram_matrix: ndarray
                    The global gram matrix.
            - If the gram matrix is not full rank, the state contains:
                - normed_log_means: ndarray
                    The normed log means.
                - n_obs: int
                    The number of observations.

        Raises
        ------
        NotImplementedError
            If ``lfc_mode`` is neither "lfc" nor "mu_init".

        """
        global_gram_matrix = self.local_adata.uns["_global_gram_matrix"]

        adata = self.refit_adata if refit_mode else self.local_adata

        # Elements to pass on to the next steps of the method
        if lfc_mode == "lfc":
            adata.uns["_irls_mu_param_name"] = "_mu_LFC"
            adata.uns["_irls_beta_param_name"] = "LFC"
            adata.uns["_irls_disp_param_name"] = "dispersions"
            adata.uns["_lfc_mode"] = "lfc"
        elif lfc_mode == "mu_init":
            adata.uns["_irls_mu_param_name"] = "_irls_mu_hat"
            adata.uns["_irls_beta_param_name"] = "_mu_hat_LFC"
            adata.uns["_irls_disp_param_name"] = "_MoM_dispersions"
            adata.uns["_lfc_mode"] = "mu_init"
        else:
            raise NotImplementedError(
                f"Only 'lfc' and 'mu_init' irls modes are supported, got {lfc_mode}."
            )

        # Get non zero genes
        non_zero_genes_names = adata.var_names[adata.varm["non_zero"]]

        # The gram matrix is full rank when its rank equals the number of
        # parameters.
        n_params = adata.uns["n_params"]
        gram_full_rank = np.linalg.matrix_rank(global_gram_matrix) == n_params

        # Fields shared in both cases. "n_params" is part of the documented
        # contract of this state, so it is shared explicitly.
        state_to_share = {
            "gram_full_rank": gram_full_rank,
            "n_non_zero_genes": len(non_zero_genes_names),
            "n_params": n_params,
        }

        if gram_full_rank:
            # Share the features vector and the gram matrix.
            design = adata.obsm["design_matrix"].values
            # The 0.1 offset keeps the logarithm finite for zero counts.
            log_counts = np.log(
                adata[:, non_zero_genes_names].layers["normed_counts"] + 0.1
            )
            log_features = (design.T @ log_counts).T
            state_to_share.update(
                {
                    "local_log_features": log_features,
                    "global_gram_matrix": global_gram_matrix,
                }
            )
        else:
            # TODO: check that this is correctly recomputed in refit mode
            if "normed_log_means" not in adata.varm:
                with np.errstate(divide="ignore"):  # ignore division by zero warnings
                    log_counts = np.log(adata.layers["normed_counts"])
                    adata.varm["normed_log_means"] = log_counts.mean(0)
            state_to_share.update(
                {
                    "normed_log_means": adata.varm["normed_log_means"],
                    "n_obs": adata.n_obs,
                }
            )
        return state_to_share
get_gram_matrix_and_log_features(data_from_opener, shared_state, lfc_mode, refit_mode=False)

Create the local quantities necessary to compute the initial beta.

To do so, we assume that the local_adata.uns contains the following fields: - n_params: int The number of parameters. - _global_gram_matrix: ndarray The global gram matrix.

From the IRLS mode, we will set the following fields: - _irls_mu_param_name: str The name of the mu parameter, to save at the end of the IRLS run This is None if we do not want to save the mu parameter. - _irls_beta_param_name: str The name of the beta parameter, to save as a varm at the end of the fed irls run This is None if we do not want to save the beta parameter. - _irls_disp_param_name: str The name of the dispersion parameter. - _lfc_mode: str The mode of the IRLS algorithm. This is used to set the previous fields.

Parameters:

Name Type Description Default
data_from_opener AnnData

Not used.

required
shared_state dict

Not used, all the necessary info is stored in the local adata.

required
lfc_mode Literal['lfc', 'mu_init']

The mode of the IRLS algorithm ("lfc", or "mu_init").

required
refit_mode bool

Whether to run the pipeline on refit_adata instead of local_adata.

False

Returns:

Type Description
dict

The state to share to the server. It always contains the following fields: - gram_full_rank: bool Whether the gram matrix is full rank. - n_non_zero_genes: int The number of non zero genes. - n_params: int The number of parameters. - If the gram matrix is full rank, the state contains: - local_log_features: ndarray The local log features. - global_gram_matrix: ndarray The global gram matrix. - If the gram matrix is not full rank, the state contains: - normed_log_means: ndarray The normed log means. - n_obs: int The number of observations.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_lfc/substeps.py
@remote_data
@log_remote_data
@reconstruct_adatas
def get_gram_matrix_and_log_features(
    self,
    data_from_opener: AnnData,
    shared_state: dict[str, Any],
    lfc_mode: Literal["lfc", "mu_init"],
    refit_mode: bool = False,
):
    """Create the local quantities necessary to compute the initial beta.

    The following fields are read:
    - ``self.local_adata.uns["_global_gram_matrix"]``: the global gram
      matrix. It is always taken from ``local_adata``, even in refit mode.
    - ``adata.uns["n_params"]``: the number of parameters.
    - ``adata.varm["non_zero"]``: mask used to select the non zero genes.

    Depending on ``lfc_mode``, the following fields of ``adata.uns`` are set:
    - _irls_mu_param_name: str
        ``"_mu_LFC"`` in "lfc" mode, ``"_irls_mu_hat"`` in "mu_init" mode.
    - _irls_beta_param_name: str
        ``"LFC"`` in "lfc" mode, ``"_mu_hat_LFC"`` in "mu_init" mode.
    - _irls_disp_param_name: str
        ``"dispersions"`` in "lfc" mode, ``"_MoM_dispersions"`` in
        "mu_init" mode.
    - _lfc_mode: str
        The value of ``lfc_mode``.

    Parameters
    ----------
    data_from_opener : AnnData
        Not used.

    shared_state : dict
        Not used, all the necessary info is stored in the local adata.

    lfc_mode : Literal["lfc", "mu_init"]
        The mode of the IRLS algorithm ("lfc", or "mu_init").

    refit_mode : bool
        Whether to run the pipeline on `refit_adata` instead of `local_adata`.

    Returns
    -------
    dict
        The state to share to the server.
        It always contains the following fields:
        - gram_full_rank: bool
            Whether the gram matrix is full rank.
        - n_non_zero_genes: int
            The number of non zero genes.
        - n_params: int
            The number of parameters.
        - If the gram matrix is full rank, the state contains:
            - local_log_features: ndarray
                The local log features.
            - global_gram_matrix: ndarray
                The global gram matrix.
        - If the gram matrix is not full rank, the state contains:
            - normed_log_means: ndarray
                The normed log means.
            - n_obs: int
                The number of observations.

    Raises
    ------
    NotImplementedError
        If ``lfc_mode`` is neither "lfc" nor "mu_init".

    """
    global_gram_matrix = self.local_adata.uns["_global_gram_matrix"]

    adata = self.refit_adata if refit_mode else self.local_adata

    # Elements to pass on to the next steps of the method
    if lfc_mode == "lfc":
        adata.uns["_irls_mu_param_name"] = "_mu_LFC"
        adata.uns["_irls_beta_param_name"] = "LFC"
        adata.uns["_irls_disp_param_name"] = "dispersions"
        adata.uns["_lfc_mode"] = "lfc"
    elif lfc_mode == "mu_init":
        adata.uns["_irls_mu_param_name"] = "_irls_mu_hat"
        adata.uns["_irls_beta_param_name"] = "_mu_hat_LFC"
        adata.uns["_irls_disp_param_name"] = "_MoM_dispersions"
        adata.uns["_lfc_mode"] = "mu_init"
    else:
        raise NotImplementedError(
            f"Only 'lfc' and 'mu_init' irls modes are supported, got {lfc_mode}."
        )

    # Get non zero genes
    non_zero_genes_names = adata.var_names[adata.varm["non_zero"]]

    # The gram matrix is full rank when its rank equals the number of
    # parameters.
    n_params = adata.uns["n_params"]
    gram_full_rank = np.linalg.matrix_rank(global_gram_matrix) == n_params

    # Fields shared in both cases. "n_params" is part of the documented
    # contract of this state, so it is shared explicitly.
    state_to_share = {
        "gram_full_rank": gram_full_rank,
        "n_non_zero_genes": len(non_zero_genes_names),
        "n_params": n_params,
    }

    if gram_full_rank:
        # Share the features vector and the gram matrix.
        design = adata.obsm["design_matrix"].values
        # The 0.1 offset keeps the logarithm finite for zero counts.
        log_counts = np.log(
            adata[:, non_zero_genes_names].layers["normed_counts"] + 0.1
        )
        log_features = (design.T @ log_counts).T
        state_to_share.update(
            {
                "local_log_features": log_features,
                "global_gram_matrix": global_gram_matrix,
            }
        )
    else:
        # TODO: check that this is correctly recomputed in refit mode
        if "normed_log_means" not in adata.varm:
            with np.errstate(divide="ignore"):  # ignore division by zero warnings
                log_counts = np.log(adata.layers["normed_counts"])
                adata.varm["normed_log_means"] = log_counts.mean(0)
        state_to_share.update(
            {
                "normed_log_means": adata.varm["normed_log_means"],
                "n_obs": adata.n_obs,
            }
        )
    return state_to_share

LocSaveLFC

Mixin to create the local quantities to compute the final hat matrix.

Attributes:

Name Type Description
local_adata AnnData

The local AnnData object.

num_jobs int

The number of cpus to use.

joblib_verbosity int

The verbosity of the joblib backend.

joblib_backend str

The backend to use for the joblib parallelization.

irls_batch_size int

The batch size to use for the IRLS algorithm.

min_mu float

The minimum value for the mu parameter.

Methods:

Name Description
save_lfc_to_local

A remote_data method. Creates the local quantities to compute the final hat matrix, which must be computed on all genes. This step is expected to be applied after catching the IRLS method with the fed prox quasi newton method, and takes as an input a shared state from the last iteration of that method.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_lfc/substeps.py
class LocSaveLFC:
    """Mixin to save the final LFC estimates in the local AnnData.

    Attributes
    ----------
    local_adata : AnnData
        The local AnnData object.
    refit_adata : AnnData
        The AnnData object written to instead of ``local_adata`` when
        ``refit_mode`` is True.
    num_jobs : int
        The number of cpus to use.
    joblib_verbosity : int
        The verbosity of the joblib backend.
    joblib_backend : str
        The backend to use for the joblib parallelization.
    irls_batch_size : int
        The batch size to use for the IRLS algorithm.
    min_mu : float
        The minimum value for the mu parameter.
    irls_num_iter : int
        Declared on the mixin; not read by the method defined here.

    Methods
    -------
    save_lfc_to_local
        A remote_data method. Stores the beta received in the shared state,
        together with the IRLS and PQN divergence masks, in the local adata.
        This step is expected to be applied after catching the IRLS method
        with the fed prox quasi newton method, and takes as an input a
        shared state from the last iteration of that method.

    """

    local_adata: AnnData
    refit_adata: AnnData
    num_jobs: int
    joblib_verbosity: int
    joblib_backend: str
    irls_batch_size: int
    min_mu: float
    irls_num_iter: int

    @remote_data
    @log_remote_data
    @reconstruct_adatas
    def save_lfc_to_local(
        self,
        data_from_opener: AnnData,
        shared_state: dict[str, Any],
        refit_mode: bool = False,
    ):
        """Save the beta estimates and the divergence masks in the local adata.

        The beta is stored in ``adata.varm`` under the name found in
        ``adata.uns["_irls_beta_param_name"]``, as a DataFrame indexed by all
        genes whose columns are the design matrix columns. Rows of genes that
        are not selected by ``adata.varm["non_zero"]`` are left as NaN.
        If ``adata.uns["_irls_mu_param_name"]`` is not None, ``set_mu_layer`` is
        then called with that name and the beta name.

        Parameters
        ----------
        data_from_opener : AnnData
            Not used.

        shared_state : dict
            The shared state.
            The shared state is a dictionary containing the following
            keys:
            - beta: ndarray
                The current beta, of shape (n_non_zero_genes, n_params).
            - irls_diverged_mask: ndarray
                A boolean mask indicating if the IRLS method has diverged.
                In that case, these genes are caught with the fed prox newton
                method.
                (shape: (n_non_zero_genes,)).
            - PQN_diverged_mask: ndarray
                A boolean mask indicating if the fed prox newton method has
                diverged. These genes are not caught by any method, and the
                returned beta value is the output of the PQN method, even
                though it has not converged.

        refit_mode : bool
            Whether to run the pipeline on `refit_adata` instead of `local_adata`.
            (default: False).

        """
        beta = shared_state["beta"]

        adata = self.refit_adata if refit_mode else self.local_adata

        # TODO keeping this in memory for now, see if need for removal at the end
        adata.uns["_irls_diverged_mask"] = shared_state["irls_diverged_mask"]
        adata.uns["_PQN_diverged_mask"] = shared_state["PQN_diverged_mask"]

        # Get the param names stored in the local adata
        mu_param_name = adata.uns["_irls_mu_param_name"]
        beta_param_name = adata.uns["_irls_beta_param_name"]

        # ---- Store beta (and optionally mu) in the adata ---- #

        design_column_names = adata.obsm["design_matrix"].columns
        non_zero_genes_names = adata.var_names[adata.varm["non_zero"]]

        # Use np.nan: the np.NaN alias was removed in NumPy 2.0.
        beta_dataframe = pd.DataFrame(
            np.nan, index=adata.var_names, columns=design_column_names
        )
        # Only the non zero genes receive a beta; the others stay NaN.
        beta_dataframe.loc[non_zero_genes_names, :] = beta

        adata.varm[beta_param_name] = beta_dataframe

        if mu_param_name is not None:
            set_mu_layer(
                local_adata=adata,
                lfc_param_name=beta_param_name,
                mu_param_name=mu_param_name,
                n_jobs=self.num_jobs,
                joblib_verbosity=self.joblib_verbosity,
                joblib_backend=self.joblib_backend,
                batch_size=self.irls_batch_size,
            )
save_lfc_to_local(data_from_opener, shared_state, refit_mode=False)

Create the local quantities to compute the final hat matrix.

Parameters:

Name Type Description Default
data_from_opener AnnData

Not used.

required
shared_state dict

The shared state. The shared state is a dictionary containing the following keys: - beta: ndarray The current beta, of shape (n_non_zero_genes, n_params). - irls_diverged_mask: ndarray A boolean mask indicating if the IRLS method has diverged. In that case, these genes are caught with the fed prox newton method. (shape: (n_non_zero_genes,)). - PQN_diverged_mask: ndarray A boolean mask indicating if the fed prox newton method has diverged. These genes are not caught by any method, and the returned beta value is the output of the PQN method, even though it has not converged.

required
refit_mode bool

Whether to run the pipeline on refit_adata instead of local_adata. (default: False).

False
Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_lfc/substeps.py
@remote_data
@log_remote_data
@reconstruct_adatas
def save_lfc_to_local(
    self,
    data_from_opener: AnnData,
    shared_state: dict[str, Any],
    refit_mode: bool = False,
):
    """Save the beta estimates and the divergence masks in the local adata.

    The beta is stored in ``adata.varm`` under the name found in
    ``adata.uns["_irls_beta_param_name"]``, as a DataFrame indexed by all
    genes whose columns are the design matrix columns. Rows of genes that
    are not selected by ``adata.varm["non_zero"]`` are left as NaN.
    If ``adata.uns["_irls_mu_param_name"]`` is not None, ``set_mu_layer`` is
    then called with that name and the beta name.

    Parameters
    ----------
    data_from_opener : AnnData
        Not used.

    shared_state : dict
        The shared state.
        The shared state is a dictionary containing the following
        keys:
        - beta: ndarray
            The current beta, of shape (n_non_zero_genes, n_params).
        - irls_diverged_mask: ndarray
            A boolean mask indicating if the IRLS method has diverged.
            In that case, these genes are caught with the fed prox newton
            method.
            (shape: (n_non_zero_genes,)).
        - PQN_diverged_mask: ndarray
            A boolean mask indicating if the fed prox newton method has
            diverged. These genes are not caught by any method, and the
            returned beta value is the output of the PQN method, even
            though it has not converged.

    refit_mode : bool
        Whether to run the pipeline on `refit_adata` instead of `local_adata`.
        (default: False).

    """
    beta = shared_state["beta"]

    adata = self.refit_adata if refit_mode else self.local_adata

    # TODO keeping this in memory for now, see if need for removal at the end
    adata.uns["_irls_diverged_mask"] = shared_state["irls_diverged_mask"]
    adata.uns["_PQN_diverged_mask"] = shared_state["PQN_diverged_mask"]

    # Get the param names stored in the local adata
    mu_param_name = adata.uns["_irls_mu_param_name"]
    beta_param_name = adata.uns["_irls_beta_param_name"]

    # ---- Store beta (and optionally mu) in the adata ---- #

    design_column_names = adata.obsm["design_matrix"].columns
    non_zero_genes_names = adata.var_names[adata.varm["non_zero"]]

    # Use np.nan: the np.NaN alias was removed in NumPy 2.0.
    beta_dataframe = pd.DataFrame(
        np.nan, index=adata.var_names, columns=design_column_names
    )
    # Only the non zero genes receive a beta; the others stay NaN.
    beta_dataframe.loc[non_zero_genes_names, :] = beta

    adata.varm[beta_param_name] = beta_dataframe

    if mu_param_name is not None:
        set_mu_layer(
            local_adata=adata,
            lfc_param_name=beta_param_name,
            mu_param_name=mu_param_name,
            n_jobs=self.num_jobs,
            joblib_verbosity=self.joblib_verbosity,
            joblib_backend=self.joblib_backend,
            batch_size=self.irls_batch_size,
        )

utils

Module to implement the utilities of the IRLS algorithm.

Most of these functions have the _batch suffix, which means that they are vectorized to work over batches of genes in the parralel_backend file in the same module.

make_irls_nll_batch(beta, design_matrix, size_factors, dispersions, counts, min_mu=0.5)

Compute the negative binomial log likelihood from LFC estimates.

Used in ComputeLFC to compute the deviance score. This function is vectorized to work over batches of genes.

Parameters:

Name Type Description Default
beta ndarray

Current LFC estimate, of shape (batch_size, n_params).

required
design_matrix ndarray

The design matrix, of shape (n_obs, n_params).

required
size_factors ndarray

The size factors, of shape (n_obs).

required
dispersions ndarray

The dispersions, of shape (batch_size).

required
counts ndarray

The counts, of shape (n_obs,batch_size).

required
min_mu float

Lower bound on estimated means, to ensure numerical stability. (default: 0.5).

0.5

Returns:

Type Description
ndarray

Local negative binomial log-likelihoods, of shape (batch_size).

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/compute_lfc/utils.py
def make_irls_nll_batch(
    beta: np.ndarray,
    design_matrix: np.ndarray,
    size_factors: np.ndarray,
    dispersions: np.ndarray,
    counts: np.ndarray,
    min_mu: float = 0.5,
) -> np.ndarray:
    """
    Evaluate the negative binomial likelihood term for a batch of LFC estimates.

    The means are rebuilt from the current LFC estimates, bounded below by
    ``min_mu``, and handed to ``grid_nb_nll`` together with the counts and the
    dispersions. Used in ComputeLFC to compute the deviance score; the function
    operates on a whole batch of genes at once.

    Parameters
    ----------
    beta : np.ndarray
        Current LFC estimate, of shape (batch_size, n_params).
    design_matrix : np.ndarray
        The design matrix, of shape (n_obs, n_params).
    size_factors : np.ndarray
        The size factors, of shape (n_obs).
    dispersions : np.ndarray
        The dispersions, of shape (batch_size).
    counts : np.ndarray
        The counts, of shape (n_obs,batch_size).
    min_mu : float
        Lower bound on estimated means, to ensure numerical stability.
        (default: ``0.5``).

    Returns
    -------
    np.ndarray
        The output of ``grid_nb_nll``: the local negative binomial
        log-likelihoods, of shape (batch_size).
    """
    # Linear predictor, one column per gene of the batch
    linear_predictor = design_matrix @ beta.T
    # Scale each observation (row) by its size factor
    raw_mu = size_factors[:, None] * np.exp(linear_predictor)
    # Bound the means from below for numerical stability
    mu = np.maximum(raw_mu, min_mu)
    return grid_nb_nll(counts, mu, dispersions)

deseq2_lfc_dispersions

DESeq2LFCDispersions

Bases: ComputeGenewiseDispersions, ComputeDispersionPrior, ComputeMAPDispersions, ComputeLFC

Mixin class to compute the log fold change and the dispersions with DESeq2.

This class encapsulates the steps to compute the log fold change and the dispersions from a given count matrix and a design matrix.

Methods:

Name Description
run_deseq2_lfc_dispersions

The method to compute the log fold change and the dispersions. It starts from the design matrix and the count matrix. It returns the shared states by the local nodes after the computation of Cook's distances. It is meant to be run two times in the main pipeline if Cook's refitting is applied.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/deseq2_lfc_dispersions.py
class DESeq2LFCDispersions(
    ComputeGenewiseDispersions,
    ComputeDispersionPrior,
    ComputeMAPDispersions,
    ComputeLFC,
):
    """Mixin chaining the dispersion and log fold change steps of DESeq2.

    It combines the genewise dispersions, dispersion prior, MAP dispersions
    and LFC mixins into a single entry point.

    Methods
    -------
    run_deseq2_lfc_dispersions
        Runs, in order: the genewise dispersions fit, the dispersion prior
        computation (or a plain update of the fitted dispersions in refit
        mode), the MAP dispersions fit and the log fold change computation.
        It returns the updated local states and round index.
        It is meant to be run two times in the main pipeline if Cook's
        refitting is applied.
    """

    def run_deseq2_lfc_dispersions(
        self,
        train_data_nodes,
        aggregation_node,
        local_states,
        gram_features_shared_states,
        round_idx,
        clean_models,
        refit_mode=False,
    ):
        """
        Run the dispersion and log fold change steps of the DESeq2 pipeline.

        Parameters
        ----------
        train_data_nodes: list
            List of TrainDataNode.

        aggregation_node: AggregationNode
            The aggregation node.

        local_states: list[dict]
            Local states. Required to propagate intermediate results.

        gram_features_shared_states: list[dict]
            Output of the "compute_size_factor step" if refit_mode is False.
            Output of the "replace_outliers" step if refit_mode is True.
            In both cases, contains a "local_features" key with the features
            vector to input to compute_genewise_dispersion.
            In the non refit mode case, it also contains a "local_gram_matrix"
            key with the local gram matrix.

        round_idx: int
            Index of the current round.

        clean_models: bool
            Whether to clean the models after the computation.

        refit_mode: bool
            Whether we are refitting Cooks outliers or not.


        Returns
        -------
        local_states: dict
            Local states updated with the results of the DESeq2 pipeline.

        round_idx: int
            The updated round index.

        """
        # ---- Stage 1: genewise dispersions ---- #

        # Note : for optimization purposes, we could avoid two successive local
        # steps here, at the cost of a more complex initialization of the
        # fit_dispersions method.
        logger.info("Fit genewise dispersions...")
        fit_output = self.fit_genewise_dispersions(
            train_data_nodes=train_data_nodes,
            aggregation_node=aggregation_node,
            gram_features_shared_states=gram_features_shared_states,
            local_states=local_states,
            round_idx=round_idx,
            clean_models=clean_models,
            refit_mode=refit_mode,
        )
        local_states, genewise_dispersions_shared_state, round_idx = fit_output
        logger.info("Finished fitting genewise dispersions.")

        # ---- Stage 2: dispersion trend ---- #
        if refit_mode:
            # No new prior in refit mode: only the fitted dispersions are updated
            trend_output = local_step(
                local_method=self.loc_update_fitted_dispersions,
                train_data_nodes=train_data_nodes,
                output_local_states=local_states,
                round_idx=round_idx,
                input_local_states=local_states,
                input_shared_state=genewise_dispersions_shared_state,
                aggregation_id=aggregation_node.organization_id,
                description="Update fitted dispersions",
                clean_models=clean_models,
            )
            local_states, dispersion_trend_share_state, round_idx = trend_output
        else:
            logger.info("Compute dispersion prior...")
            trend_output = self.compute_dispersion_prior(
                train_data_nodes,
                aggregation_node,
                local_states,
                genewise_dispersions_shared_state,
                round_idx,
                clean_models,
            )
            local_states, dispersion_trend_share_state, round_idx = trend_output
            logger.info("Finished computing dispersion prior.")

        # ---- Stage 3: MAP dispersions ---- #
        # In refit mode, the shared state of the previous stage is not forwarded.
        map_input_shared_state = None if refit_mode else dispersion_trend_share_state
        logger.info("Fit MAP dispersions...")
        local_states, round_idx = self.fit_MAP_dispersions(
            train_data_nodes=train_data_nodes,
            aggregation_node=aggregation_node,
            local_states=local_states,
            shared_state=map_input_shared_state,
            round_idx=round_idx,
            clean_models=clean_models,
            refit_mode=refit_mode,
        )
        logger.info("Finished fitting MAP dispersions.")

        # ---- Stage 4: log fold changes ---- #
        # NOTE(review): clean_models is hard-coded to True here instead of
        # forwarding the caller's value -- confirm that this is intentional.
        logger.info("Compute log fold changes...")
        local_states, round_idx = self.compute_lfc(
            train_data_nodes,
            aggregation_node,
            local_states,
            round_idx,
            clean_models=True,
            lfc_mode="lfc",
            refit_mode=refit_mode,
        )
        logger.info("Finished computing log fold changes.")

        return local_states, round_idx

run_deseq2_lfc_dispersions(train_data_nodes, aggregation_node, local_states, gram_features_shared_states, round_idx, clean_models, refit_mode=False)

Run the DESeq2 pipeline to compute the log fold change and the dispersions.

Parameters:

Name Type Description Default
train_data_nodes

List of TrainDataNode.

required
aggregation_node

The aggregation node.

required
local_states

Local states. Required to propagate intermediate results.

required
gram_features_shared_states

Output of the "compute_size_factor step" if refit_mode is False. Output of the "replace_outliers" step if refit_mode is True. In both cases, contains a "local_features" key with the features vector to input to compute_genewise_dispersion. In the non refit mode case, it also contains a "local_gram_matrix" key with the local gram matrix.

required
round_idx

Index of the current round.

required
clean_models

Whether to clean the models after the computation.

required
refit_mode

Whether we are refitting Cooks outliers or not.

False

Returns:

Name Type Description
local_states dict

Local states updated with the results of the DESeq2 pipeline.

round_idx int

The updated round index.

Source code in fedpydeseq2/core/deseq2_core/deseq2_lfc_dispersions/deseq2_lfc_dispersions.py
def run_deseq2_lfc_dispersions(
    self,
    train_data_nodes,
    aggregation_node,
    local_states,
    gram_features_shared_states,
    round_idx,
    clean_models,
    refit_mode=False,
):
    """Chain the dispersion and log fold change steps of the DESeq2 pipeline.

    The steps run in the following order: genewise dispersions, dispersion
    prior (replaced by a local update of the fitted dispersions in refit
    mode), MAP dispersions, and finally log fold changes.

    Parameters
    ----------
    train_data_nodes: list
        List of TrainDataNode.

    aggregation_node: AggregationNode
        The aggregation node.

    local_states: list[dict]
        Local states. Required to propagate intermediate results.

    gram_features_shared_states: list[dict]
        Output of the "compute_size_factor" step if refit_mode is False.
        Output of the "replace_outliers" step if refit_mode is True.
        In both cases, contains a "local_features" key with the features
        vector to input to compute_genewise_dispersion.
        In the non-refit mode case, it also contains a "local_gram_matrix"
        key with the local gram matrix.

    round_idx: int
        Index of the current round.

    clean_models: bool
        Whether to clean the models after the computation.
        Note that the log fold change step is always run with
        ``clean_models=True``, whatever the value passed here.

    refit_mode: bool
        Whether we are refitting Cooks outliers or not.

    Returns
    -------
    local_states: dict
        Local states updated with the results of the DESeq2 pipeline.

    round_idx: int
        The updated round index.

    """
    # ---- Step 1: genewise dispersions ----
    # Optimization note: this step is followed by another local step. The two
    # could be merged, but the fit_dispersions method would then need a more
    # complex initialization.
    logger.info("Fit genewise dispersions...")
    genewise_outputs = self.fit_genewise_dispersions(
        train_data_nodes=train_data_nodes,
        aggregation_node=aggregation_node,
        gram_features_shared_states=gram_features_shared_states,
        local_states=local_states,
        round_idx=round_idx,
        clean_models=clean_models,
        refit_mode=refit_mode,
    )
    local_states, genewise_shared_state, round_idx = genewise_outputs
    logger.info("Finished fitting genewise dispersions.")

    # ---- Step 2: dispersion prior, or fitted dispersions update ----
    if refit_mode:
        # Refit mode skips compute_dispersion_prior: only the fitted
        # dispersions are updated through a local step.
        local_states, _unused_shared_state, round_idx = local_step(
            local_method=self.loc_update_fitted_dispersions,
            train_data_nodes=train_data_nodes,
            output_local_states=local_states,
            round_idx=round_idx,
            input_local_states=local_states,
            input_shared_state=genewise_shared_state,
            aggregation_id=aggregation_node.organization_id,
            description="Update fitted dispersions",
            clean_models=clean_models,
        )
        # The shared state produced by the local step is not forwarded to the
        # MAP dispersion fit.
        map_shared_state = None
    else:
        logger.info("Compute dispersion prior...")
        prior_outputs = self.compute_dispersion_prior(
            train_data_nodes,
            aggregation_node,
            local_states,
            genewise_shared_state,
            round_idx,
            clean_models,
        )
        local_states, map_shared_state, round_idx = prior_outputs
        logger.info("Finished computing dispersion prior.")

    # ---- Step 3: MAP dispersions ----
    logger.info("Fit MAP dispersions...")
    map_outputs = self.fit_MAP_dispersions(
        train_data_nodes=train_data_nodes,
        aggregation_node=aggregation_node,
        local_states=local_states,
        shared_state=map_shared_state,
        round_idx=round_idx,
        clean_models=clean_models,
        refit_mode=refit_mode,
    )
    local_states, round_idx = map_outputs
    logger.info("Finished fitting MAP dispersions.")

    # ---- Step 4: log fold changes ----
    logger.info("Compute log fold changes...")
    lfc_outputs = self.compute_lfc(
        train_data_nodes,
        aggregation_node,
        local_states,
        round_idx,
        clean_models=True,
        lfc_mode="lfc",
        refit_mode=refit_mode,
    )
    local_states, round_idx = lfc_outputs
    logger.info("Finished computing log fold changes.")

    return local_states, round_idx