Skip to content

Computing Cook's distances

Pipeline step computing Cook's distances.

compute_cook_distance

ComputeCookDistances

Bases: ComputeTrimmedMean, LocComputeSqerror, LocGetNormedCounts, AggComputeDispersionForCook

Mixin class to compute Cook's distances.

Methods:

Name Description
compute_cook_distance

The method to compute Cook's distances.

Source code in fedpydeseq2/core/deseq2_core/compute_cook_distance/compute_cook_distance.py
class ComputeCookDistances(
    ComputeTrimmedMean,
    LocComputeSqerror,
    LocGetNormedCounts,
    AggComputeDispersionForCook,
):
    """Mixin orchestrating the federated steps behind Cook's distances.

    The local and aggregation building blocks are inherited from the base
    classes; this class only chains them in the right order.

    Attributes
    ----------
    trimmed_mean_num_iter : int
        Forwarded as ``n_iter`` to every ``compute_trim_mean`` call.

    Methods
    -------
    compute_cook_distance
        Run the whole sequence of steps and return the resulting states.
    """

    trimmed_mean_num_iter: int

    @log_organisation_method
    def compute_cook_distance(
        self,
        train_data_nodes,
        aggregation_node,
        local_states,
        round_idx,
        clean_models,
    ):
        """Run the steps producing the dispersions used for Cook's distances.

        The steps are chained as follows, each one consuming the round index
        returned by the previous one:

        1. trimmed mean of the ``"normed_counts"`` layer (``"cooks"`` mode);
        2. local step ``local_compute_sqerror``, fed with the result of 1;
        3. trimmed mean of the ``"sqerror"`` layer (``"cooks"`` mode);
        4. local step ``local_get_normed_count_means``, fed with the result
           of 3;
        5. aggregation step ``agg_compute_dispersion_for_cook`` on the shared
           states produced by 4.

        Parameters
        ----------
        train_data_nodes: list
            list of TrainDataNode.

        aggregation_node: AggregationNode
            The aggregation node.

        local_states: list[dict]
            Local states. Required to propagate intermediate results.

        round_idx: int
            Index of the current round.

        clean_models: bool
            Whether to clean the models after the computation.

        Returns
        -------
        local_states: dict
            Local states returned by the last local step (step 4).

        dispersion_for_cook_shared_state: dict
            Shared state with the dispersion values for Cook's distances, in a
            "cooks_dispersions" key.

        round_idx: int
            The updated round index.
        """
        # Keyword arguments shared by the two trimmed-mean passes.
        cooks_trim_options = {
            "clean_models": clean_models,
            "mode": "cooks",
            "trim_ratio": None,
        }

        # Step 1: trimmed mean of the normalized counts.
        local_states, trimmed_counts_state, round_idx = self.compute_trim_mean(
            train_data_nodes,
            aggregation_node,
            local_states,
            round_idx,
            layer_used="normed_counts",
            n_iter=self.trimmed_mean_num_iter,
            **cooks_trim_options,
        )

        # Step 2: squared errors, computed locally. The shared states returned
        # by this step are never read afterwards.
        local_states, _unused_shared_states, round_idx = local_step(
            local_method=self.local_compute_sqerror,
            train_data_nodes=train_data_nodes,
            output_local_states=local_states,
            input_local_states=local_states,
            input_shared_state=trimmed_counts_state,
            aggregation_id=aggregation_node.organization_id,
            description="Compute local sqerror",
            round_idx=round_idx,
            clean_models=clean_models,
        )

        # Step 3: trimmed mean of the squared errors.
        local_states, trimmed_sqerror_state, round_idx = self.compute_trim_mean(
            train_data_nodes,
            aggregation_node,
            local_states,
            round_idx,
            layer_used="sqerror",
            n_iter=self.trimmed_mean_num_iter,
            **cooks_trim_options,
        )

        # Step 4: each node sends its normed count means.
        local_states, normed_count_shared_states, round_idx = local_step(
            local_method=self.local_get_normed_count_means,
            train_data_nodes=train_data_nodes,
            output_local_states=local_states,
            input_local_states=local_states,
            input_shared_state=trimmed_sqerror_state,
            aggregation_id=aggregation_node.organization_id,
            description="Get normed count means",
            round_idx=round_idx,
            clean_models=clean_models,
        )

        # Step 5: the aggregation node derives the dispersions.
        cooks_dispersion_state, round_idx = aggregation_step(
            aggregation_method=self.agg_compute_dispersion_for_cook,
            train_data_nodes=train_data_nodes,
            aggregation_node=aggregation_node,
            input_shared_states=normed_count_shared_states,
            description="Compute dispersion for Cook distances",
            round_idx=round_idx,
            clean_models=clean_models,
        )

        return local_states, cooks_dispersion_state, round_idx

compute_cook_distance(train_data_nodes, aggregation_node, local_states, round_idx, clean_models)

Compute Cook's distances.

Parameters:

Name Type Description Default
train_data_nodes

list of TrainDataNode.

required
aggregation_node

The aggregation node.

required
local_states

Local states. Required to propagate intermediate results.

required
round_idx

Index of the current round.

required
clean_models

Whether to clean the models after the computation.

required

Returns:

Name Type Description
local_states dict

Local states. The new local state contains Cook's distances.

dispersion_for_cook_shared_state dict

Shared state with the dispersion values for Cook's distances, in a "cooks_dispersions" key.

round_idx int

The updated round index.

Source code in fedpydeseq2/core/deseq2_core/compute_cook_distance/compute_cook_distance.py
@log_organisation_method
def compute_cook_distance(
    self,
    train_data_nodes,
    aggregation_node,
    local_states,
    round_idx,
    clean_models,
):
    """Run the steps producing the dispersions used for Cook's distances.

    The steps are chained as follows, each one consuming the round index
    returned by the previous one:

    1. trimmed mean of the ``"normed_counts"`` layer (``"cooks"`` mode);
    2. local step ``local_compute_sqerror``, fed with the result of 1;
    3. trimmed mean of the ``"sqerror"`` layer (``"cooks"`` mode);
    4. local step ``local_get_normed_count_means``, fed with the result
       of 3;
    5. aggregation step ``agg_compute_dispersion_for_cook`` on the shared
       states produced by 4.

    Parameters
    ----------
    train_data_nodes: list
        list of TrainDataNode.

    aggregation_node: AggregationNode
        The aggregation node.

    local_states: list[dict]
        Local states. Required to propagate intermediate results.

    round_idx: int
        Index of the current round.

    clean_models: bool
        Whether to clean the models after the computation.

    Returns
    -------
    local_states: dict
        Local states returned by the last local step (step 4).

    dispersion_for_cook_shared_state: dict
        Shared state with the dispersion values for Cook's distances, in a
        "cooks_dispersions" key.

    round_idx: int
        The updated round index.
    """
    # Keyword arguments shared by the two trimmed-mean passes.
    cooks_trim_options = {
        "clean_models": clean_models,
        "mode": "cooks",
        "trim_ratio": None,
    }

    # Step 1: trimmed mean of the normalized counts.
    local_states, trimmed_counts_state, round_idx = self.compute_trim_mean(
        train_data_nodes,
        aggregation_node,
        local_states,
        round_idx,
        layer_used="normed_counts",
        n_iter=self.trimmed_mean_num_iter,
        **cooks_trim_options,
    )

    # Step 2: squared errors, computed locally. The shared states returned
    # by this step are never read afterwards.
    local_states, _unused_shared_states, round_idx = local_step(
        local_method=self.local_compute_sqerror,
        train_data_nodes=train_data_nodes,
        output_local_states=local_states,
        input_local_states=local_states,
        input_shared_state=trimmed_counts_state,
        aggregation_id=aggregation_node.organization_id,
        description="Compute local sqerror",
        round_idx=round_idx,
        clean_models=clean_models,
    )

    # Step 3: trimmed mean of the squared errors.
    local_states, trimmed_sqerror_state, round_idx = self.compute_trim_mean(
        train_data_nodes,
        aggregation_node,
        local_states,
        round_idx,
        layer_used="sqerror",
        n_iter=self.trimmed_mean_num_iter,
        **cooks_trim_options,
    )

    # Step 4: each node sends its normed count means.
    local_states, normed_count_shared_states, round_idx = local_step(
        local_method=self.local_get_normed_count_means,
        train_data_nodes=train_data_nodes,
        output_local_states=local_states,
        input_local_states=local_states,
        input_shared_state=trimmed_sqerror_state,
        aggregation_id=aggregation_node.organization_id,
        description="Get normed count means",
        round_idx=round_idx,
        clean_models=clean_models,
    )

    # Step 5: the aggregation node derives the dispersions.
    cooks_dispersion_state, round_idx = aggregation_step(
        aggregation_method=self.agg_compute_dispersion_for_cook,
        train_data_nodes=train_data_nodes,
        aggregation_node=aggregation_node,
        input_shared_states=normed_count_shared_states,
        description="Compute dispersion for Cook distances",
        round_idx=round_idx,
        clean_models=clean_models,
    )

    return local_states, cooks_dispersion_state, round_idx

substeps

AggComputeDispersionForCook

Compute the dispersion for Cook's distance calculation.

Source code in fedpydeseq2/core/deseq2_core/compute_cook_distance/substeps.py
class AggComputeDispersionForCook:
    """Compute the dispersion for Cook's distance calculation.

    Mixin exposing a single aggregation method,
    ``agg_compute_dispersion_for_cook``. The method body is a placeholder:
    the documented output is produced by its decorators.
    """

    @remote
    @log_remote
    @prepare_cooks_agg
    def agg_compute_dispersion_for_cook(
        self,
        shared_states: list[dict],
    ) -> dict:
        """Compute the dispersion for Cook's distance calculation.

        The undecorated body returns an empty dictionary. The result described
        below comes from the decoration (presumably ``prepare_cooks_agg`` --
        confirm against that decorator's implementation).

        Parameters
        ----------
        shared_states : list[dict]
            list of shared states with the following keys:
            - mean_normed_counts: mean of the normalized counts
            - n_samples: number of samples
            - varEst: variance estimate

        Returns
        -------
        dict
            Because it is decorated, the dictionary will have the following key:
            - cooks_dispersions: dispersion values
        """
        # Placeholder result: nothing is computed here, see the docstring.
        return {}

agg_compute_dispersion_for_cook(shared_states)

Compute the dispersion for Cook's distance calculation.

Parameters:

Name Type Description Default
shared_states list[dict]

list of shared states with the following keys: - mean_normed_counts: mean of the normalized counts - n_samples: number of samples - varEst: variance estimate

required

Returns:

Type Description
dict

Because it is decorated, the dictionary will have the following key: - cooks_dispersions: dispersion values

Source code in fedpydeseq2/core/deseq2_core/compute_cook_distance/substeps.py
@remote
@log_remote
@prepare_cooks_agg
def agg_compute_dispersion_for_cook(
    self,
    shared_states: list[dict],
) -> dict:
    """Compute the dispersion for Cook's distance calculation.

    The undecorated body returns an empty dictionary. The result described
    below comes from the decoration (presumably ``prepare_cooks_agg`` --
    confirm against that decorator's implementation).

    Parameters
    ----------
    shared_states : list[dict]
        list of shared states with the following keys:
        - mean_normed_counts: mean of the normalized counts
        - n_samples: number of samples
        - varEst: variance estimate

    Returns
    -------
    dict
        Because it is decorated, the dictionary will have the following key:
        - cooks_dispersions: dispersion values
    """
    # Placeholder result: nothing is computed here, see the docstring.
    return {}

LocComputeSqerror

Compute the squared error between the normalized counts and the trimmed mean.

Source code in fedpydeseq2/core/deseq2_core/compute_cook_distance/substeps.py
class LocComputeSqerror:
    """Local step storing the trimmed means and deriving the squared errors.

    Attributes
    ----------
    local_adata : AnnData
        Local AnnData object; it is updated in place by
        ``local_compute_sqerror``.
    """

    local_adata: AnnData

    # NOTE(review): the default value of ``shared_state`` below is the builtin
    # ``dict`` type itself, not an annotation; presumably ``shared_state: dict``
    # was intended. The signature is left untouched here.
    @remote_data
    @log_remote_data
    @reconstruct_adatas
    def local_compute_sqerror(
        self,
        data_from_opener,
        shared_state=dict,
    ) -> None:
        """Store the trimmed means locally, then set the squared error layer.

        The value found under ``"trimmed_mean_normed_counts"`` is written to
        ``self.local_adata.varm["cell_means"]``, after which
        ``set_sqerror_layer`` is called on ``self.local_adata``.

        Parameters
        ----------
        data_from_opener : ad.AnnData
            AnnData returned by the opener. Not used.

        shared_state : dict, optional
            Shared state holding a ``"trimmed_mean_normed_counts"`` entry. When
            that entry is a ``pd.DataFrame``, its index is overwritten in place
            with ``self.local_adata.var_names``.
        """
        trimmed_means = shared_state["trimmed_mean_normed_counts"]
        # Only a DataFrame needs re-indexing on the local gene names; in any
        # other case the cell means are not computed per level but overall,
        # and the value is stored unchanged.
        if isinstance(trimmed_means, pd.DataFrame):
            trimmed_means.index = self.local_adata.var_names
        self.local_adata.varm["cell_means"] = trimmed_means
        set_sqerror_layer(self.local_adata)

local_compute_sqerror(data_from_opener, shared_state=dict)

Compute the squared error between the normalized counts and the trimmed mean.

Parameters:

Name Type Description Default
data_from_opener AnnData

AnnData returned by the opener. Not used.

required
shared_state dict

Results to save in the local states.

dict
Source code in fedpydeseq2/core/deseq2_core/compute_cook_distance/substeps.py
# NOTE(review): the default value of ``shared_state`` below is the builtin
# ``dict`` type itself, not an annotation; presumably ``shared_state: dict``
# was intended. The signature is left untouched here.
@remote_data
@log_remote_data
@reconstruct_adatas
def local_compute_sqerror(
    self,
    data_from_opener,
    shared_state=dict,
) -> None:
    """Store the trimmed means locally, then set the squared error layer.

    The value found under ``"trimmed_mean_normed_counts"`` is written to
    ``self.local_adata.varm["cell_means"]``, after which
    ``set_sqerror_layer`` is called on ``self.local_adata``.

    Parameters
    ----------
    data_from_opener : ad.AnnData
        AnnData returned by the opener. Not used.

    shared_state : dict, optional
        Shared state holding a ``"trimmed_mean_normed_counts"`` entry. When
        that entry is a ``pd.DataFrame``, its index is overwritten in place
        with ``self.local_adata.var_names``.
    """
    trimmed_means = shared_state["trimmed_mean_normed_counts"]
    # Only a DataFrame needs re-indexing on the local gene names; in any
    # other case the cell means are not computed per level but overall,
    # and the value is stored unchanged.
    if isinstance(trimmed_means, pd.DataFrame):
        trimmed_means.index = self.local_adata.var_names
    self.local_adata.varm["cell_means"] = trimmed_means
    set_sqerror_layer(self.local_adata)

LocGetNormedCounts

Get the mean of the normalized counts.

Source code in fedpydeseq2/core/deseq2_core/compute_cook_distance/substeps.py
class LocGetNormedCounts:
    """Get the mean of the normalized counts.

    Mixin exposing a single local method, ``local_get_normed_count_means``.
    The method body is a placeholder: the documented output is produced by
    its decorators.
    """

    # Local AnnData object (not read directly by the method body below).
    local_adata: AnnData

    @remote_data
    @log_remote_data
    @reconstruct_adatas
    @prepare_cooks_local
    def local_get_normed_count_means(
        self,
        data_from_opener,
        # NOTE(review): this default is the builtin ``dict`` type itself, not
        # an annotation; presumably ``shared_state: dict`` was intended.
        shared_state=dict,
    ) -> dict:
        """Send local normed counts means.

        The undecorated body returns an empty dictionary. The result described
        below comes from the decoration (presumably ``prepare_cooks_local`` --
        confirm against that decorator's implementation).

        Parameters
        ----------
        data_from_opener : ad.AnnData
            AnnData returned by the opener. Not used.

        shared_state : dict, optional
            Dictionary with the following keys:
            - varEst: variance estimate for Cook's distance calculation

        Returns
        -------
        dict
            Because of the decorator, dictionary with the following keys:
            - mean_normed_counts: mean of the normalized counts
            - n_samples: number of samples
            - varEst: variance estimate
        """
        # Placeholder result: nothing is computed here, see the docstring.
        return {}

local_get_normed_count_means(data_from_opener, shared_state=dict)

Send local normed counts means.

Parameters:

Name Type Description Default
data_from_opener AnnData

AnnData returned by the opener. Not used.

required
shared_state dict

Dictionary with the following keys: - varEst: variance estimate for Cook's distance calculation

dict

Returns:

Type Description
dict

Because of the decorator, dictionary with the following keys: - mean_normed_counts: mean of the normalized counts - n_samples: number of samples - varEst: variance estimate

Source code in fedpydeseq2/core/deseq2_core/compute_cook_distance/substeps.py
@remote_data
@log_remote_data
@reconstruct_adatas
@prepare_cooks_local
def local_get_normed_count_means(
    self,
    data_from_opener,
    # NOTE(review): this default is the builtin ``dict`` type itself, not
    # an annotation; presumably ``shared_state: dict`` was intended.
    shared_state=dict,
) -> dict:
    """Send local normed counts means.

    The undecorated body returns an empty dictionary. The result described
    below comes from the decoration (presumably ``prepare_cooks_local`` --
    confirm against that decorator's implementation).

    Parameters
    ----------
    data_from_opener : ad.AnnData
        AnnData returned by the opener. Not used.

    shared_state : dict, optional
        Dictionary with the following keys:
        - varEst: variance estimate for Cook's distance calculation

    Returns
    -------
    dict
        Because of the decorator, dictionary with the following keys:
        - mean_normed_counts: mean of the normalized counts
        - n_samples: number of samples
        - varEst: variance estimate
    """
    # Placeholder result: nothing is computed here, see the docstring.
    return {}