
Federated IRLS

Module which contains the Mixin in charge of performing FedIRLS.

fed_irls

Module containing the ComputeLFC method.

FedIRLS

Bases: LocMakeIRLSSummands, AggMakeIRLSUpdate

Mixin class to implement the LFC computation algorithm.

The goal of this class is to implement the IRLS algorithm specifically applied to the negative binomial distribution, with a fixed dispersion parameter (only the mean parameter, expressed as the exponential of the log fold changes times the design matrix, is estimated). Genes on which this algorithm fails are handled by a fallback method.

To the best of our knowledge, there is no explicit implementation of IRLS for the negative binomial in a federated setting. However, the steps of IRLS are akin to the ones of a Newton-Raphson algorithm, with the difference that the Hessian matrix is replaced by the Fisher information matrix.

Let us recall the steps of the IRLS algorithm for one gene (this method then implements these iterations for all genes in parallel). We want to estimate the log fold changes :math:`\beta` from the counts :math:`y` and the design matrix :math:`X`. The negative binomial log-likelihood is given by:

.. math:: \mathcal{L}(\beta) = \sum_{i=1}^n \left( y_i \log(\mu_i) - (y_i + \alpha^{-1}) \log(\mu_i + \alpha^{-1}) \right) + \text{const}(y, \alpha)

where :math:`\mu_i = \gamma_i\exp(X_i \cdot \beta)`, :math:`\gamma_i` is the size factor of sample :math:`i`, and :math:`\alpha` is the dispersion parameter.
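The formula can be checked numerically. The following minimal NumPy sketch evaluates the log-likelihood of a single gene, up to the constant term. It is for illustration only: `nb_loglik` and the toy data are hypothetical and not part of fedpydeseq2.

```python
import numpy as np


def nb_loglik(beta, y, X, gamma, alpha):
    """NB log-likelihood of one gene, up to a constant in (y, alpha).

    beta: (n_params,), y: (n,) counts, X: (n, n_params) design,
    gamma: (n,) size factors, alpha: scalar dispersion.
    """
    mu = gamma * np.exp(X @ beta)  # mu_i = gamma_i * exp(X_i . beta)
    inv_alpha = 1.0 / alpha
    return np.sum(y * np.log(mu) - (y + inv_alpha) * np.log(mu + inv_alpha))


# Toy example: 4 samples, intercept + one binary condition.
X = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 1.0], [1.0, 1.0]])
y = np.array([10.0, 12.0, 30.0, 25.0])
gamma = np.array([1.0, 1.2, 0.9, 1.1])
alpha = 0.1
ll = nb_loglik(np.array([2.3, 0.9]), y, X, gamma, alpha)
print(ll)
```

The dropped constant does not depend on :math:`\beta`, so it cancels in differences between candidate values of :math:`\beta`, which is all that IRLS and the deviance-based stopping rule rely on.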

Given an iterate :math:`\beta_k`, the IRLS algorithm computes the next iterate :math:`\beta_{k+1}` as follows.

First, we compute the mean parameter :math:`\mu_k` from the current iterate, using the formula of the log fold changes:

.. math:: (\mu_{k})_i = \gamma_i \exp(X_i \cdot \beta_k)

In practice, the values of :math:`\mu_k` are clipped from below at a minimum value (min_mu) to ensure numerical stability.

Then, we compute the weight matrix :math:`W_k` from the current iterate :math:`\beta_k`; it is a diagonal matrix with diagonal elements:

.. math:: (W_k)_{ii} = \frac{\mu_{k,i}}{1 + \mu_{k,i} \alpha}

where :math:`\alpha` is the dispersion parameter. This weight matrix is used to compute both the estimated variance (or hat matrix) and the feature vector :math:`z_k`:

.. math:: z_k = \log\left(\frac{\mu_k}{\gamma}\right) + \frac{y - \mu_k}{\mu_k}

The estimated variance is given by:

.. math:: H_k = X^T W_k X

The update step is then given by:

.. math:: \beta_{k+1} = (H_k)^{-1} X^T W_k z_k

This is akin to the Newton-Raphson algorithm, with the Hessian matrix replaced by the Fisher information, and the gradient replaced by the feature vector.
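Put together, the steps above give the following centralized reference sketch for a single gene. It is an illustration under assumptions, not fedpydeseq2's code:

- The name `irls_nb`, the intercept initialization, the stopping rule on the change in beta and `min_mu=0.5` are choices made here for the example.
- The `1e-6` ridge term mirrors the small diagonal factor that `make_global_irls_update` adds to the aggregated hat matrix before solving.
- Counts are simulated from a gamma-Poisson mixture, which yields negative binomial counts with mean mu and dispersion alpha.

```python
import numpy as np


def irls_nb(y, X, gamma, alpha, n_iter=50, tol=1e-8, min_mu=0.5, ridge=1e-6):
    """IRLS for one gene with fixed dispersion alpha (illustrative sketch).

    Each step solves (X^T W X + ridge * I) beta = X^T W z.
    """
    n_params = X.shape[1]
    beta = np.zeros(n_params)
    beta[0] = np.log(np.mean(y / gamma))  # crude intercept initialization (assumption)
    for _ in range(n_iter):
        mu = np.maximum(gamma * np.exp(X @ beta), min_mu)  # clipped mean
        w = mu / (1.0 + mu * alpha)                        # diagonal of W_k
        z = np.log(mu / gamma) + (y - mu) / mu             # feature vector z_k
        H = X.T @ (w[:, None] * X)                         # X^T W_k X
        rhs = X.T @ (w * z)                                # X^T W_k z_k
        new_beta = np.linalg.solve(H + ridge * np.eye(n_params), rhs)
        converged = np.max(np.abs(new_beta - beta)) < tol
        beta = new_beta
        if converged:
            break
    return beta


rng = np.random.default_rng(0)
n = 200
X = np.column_stack([np.ones(n), rng.integers(0, 2, n)]).astype(float)
gamma = rng.uniform(0.5, 1.5, n)
true_beta = np.array([3.0, 1.0])
alpha = 0.05
mu_true = gamma * np.exp(X @ true_beta)
# NB(mean mu, dispersion alpha) drawn as a gamma-Poisson mixture.
y = rng.poisson(rng.gamma(shape=1 / alpha, scale=mu_true * alpha)).astype(float)
beta_hat = irls_nb(y, X, gamma, alpha)
print(beta_hat)
```

When no clipping is active, a fixed point satisfies :math:`X^T W_k (z_k - X\beta_k) = X^T \frac{y - \mu_k}{1 + \alpha\mu_k} = 0`, up to the ridge term. This is exactly the score equation of the log-likelihood above, which is why the iterations land on the maximum-likelihood estimate.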

Methods:

- run_fed_irls: Run the IRLS algorithm.

Source code in fedpydeseq2/core/fed_algorithms/fed_irls/fed_irls.py
class FedIRLS(
    LocMakeIRLSSummands,
    AggMakeIRLSUpdate,
):
    r"""Mixin class to implement the LFC computation algorithm.

    The goal of this class is to implement the IRLS algorithm specifically applied
    to the negative binomial distribution, with fixed dispersion parameter (only
    the mean parameter, expressed as the exponential of the log fold changes times
    the design matrix, is estimated). Genes on which this algorithm fails are
    handled by a fallback method.

    To the best of our knowledge, there is no explicit implementation of IRLS for the
    negative binomial in a federated setting. However, the steps of IRLS are akin
    to the ones of a Newton-Raphson algorithm, with the difference that the Hessian
    matrix is replaced by the Fisher information matrix.

    Let us recall the steps of the IRLS algorithm for one gene (this method then
    implements these iterations for all genes in parallel).
    We want to estimate the log fold changes :math:`\beta` from the counts :math:`y`
    and the design matrix :math:`X`. The negative binomial likelihood is given by:

    .. math::
        \mathcal{L}(\beta) = \sum_{i=1}^n \left( y_i \log(\mu_i) -
        (y_i + \alpha^{-1}) \log(\mu_i + \alpha^{-1}) \right) + \text{const}(y, \alpha)

    where :math:`\mu_i = \gamma_i\exp(X_i \cdot \beta)` and :math:`\alpha` is
    the dispersion parameter.

    Given an iterate :math:`\beta_k`, the IRLS algorithm computes the next iterate
    :math:`\beta_{k+1}` as follows.

    First, we compute the mean parameter :math:`\mu_k` from the current iterate, using
    the formula of the log fold changes:

    .. math::
        (\mu_{k})_i = \gamma_i \exp(X_i \cdot \beta_k)

    In practice, we trim the values of :math:`\mu_k` to a minimum value to ensure
    numerical stability.

    Then, we compute the weight matrix :math:`W_k` from the current iterate
    :math:`\beta_k`, which is a diagonal matrix with diagonal elements:

    .. math::
        (W_k)_{ii} = \frac{\mu_{k,i}}{1 + \mu_{k,i} \alpha}

    where :math:`\alpha` is the dispersion parameter.
    This weight matrix is used to compute both the estimated variance (or hat matrix)
    and the feature vector :math:`z_k`:

    .. math::
        z_k = \log\left(\frac{\mu_k}{\gamma}\right) + \frac{y - \mu_k}{\mu_k}

    The estimated variance is given by:

    .. math::
        H_k = X^T W_k X

    The update step is then given by:

    .. math::
        \beta_{k+1} = (H_k)^{-1} X^T W_k z_k

    This is akin to the Newton-Raphson algorithm, with the
    Hessian matrix replaced by the Fisher information, and the gradient replaced by the
    feature vector.

    Methods
    -------
    run_fed_irls
        Run the IRLS algorithm.

    """

    def run_fed_irls(
        self,
        train_data_nodes: list,
        aggregation_node: AggregationNode,
        local_states: dict,
        input_shared_state: dict,
        round_idx: int,
        clean_models: bool = True,
        refit_mode: bool = False,
    ):
        """Run the IRLS algorithm.

        Parameters
        ----------
        train_data_nodes: list
            List of TrainDataNode.

        aggregation_node: AggregationNode
            The aggregation node.

        local_states: dict
            Local states. Required to propagate intermediate results.

        input_shared_state: dict
            Shared state with the following keys:
            - beta: ndarray
                The current beta, of shape (n_non_zero_genes, n_params).
            - irls_diverged_mask: ndarray
                A boolean mask indicating if fed avg should be used for a given gene
                (shape: (n_non_zero_genes,)).
            - irls_mask: ndarray
                A boolean mask indicating if IRLS should be used for a given gene
                (shape: (n_non_zero_genes,)).
            - global_nll: ndarray
                The global_nll of the current beta from the previous beta, of shape\
                (n_non_zero_genes,).
            - round_number_irls: int
                The current round number of the IRLS algorithm.

        round_idx: int
            The current round.

        clean_models: bool
            If True, the models are cleaned.

        refit_mode: bool
            Whether to run on `refit_adata`s instead of `local_adata`s.
            (default: False).

        Returns
        -------
        local_states: dict
            Local states. Required to propagate intermediate results.

        global_irls_summands_nlls_shared_state: dict
            Shared states containing the final IRLS results, with the following keys:
            - beta: ndarray
                The current beta, of shape (n_non_zero_genes, n_params).
            - irls_diverged_mask: ndarray
                A boolean mask indicating if fed avg should be used for a given gene
                (shape: (n_non_zero_genes,)).
            - irls_mask: ndarray
                A boolean mask indicating if IRLS should be used for a given gene
                (shape: (n_non_zero_genes,)).
            - global_nll: ndarray
                The global_nll of the current beta from the previous beta, of shape\
                (n_non_zero_genes,).
            - round_number_irls: int
                The current round number of the IRLS algorithm.

        round_idx: int
            The updated round index.

        """
        #### ---- Main training loop ---- #####

        global_irls_summands_nlls_shared_state = input_shared_state

        for _ in range(self.irls_num_iter + 1):
            # ---- Compute local IRLS summands and nlls ---- #

            (
                local_states,
                local_irls_summands_nlls_shared_states,
                round_idx,
            ) = local_step(
                local_method=self.make_local_irls_summands_and_nlls,
                train_data_nodes=train_data_nodes,
                output_local_states=local_states,
                round_idx=round_idx,
                input_local_states=local_states,
                input_shared_state=global_irls_summands_nlls_shared_state,
                aggregation_id=aggregation_node.organization_id,
                description="Compute local IRLS summands and nlls.",
                clean_models=clean_models,
                method_params={"refit_mode": refit_mode},
            )

            # ---- Compute global IRLS update and nlls ---- #

            global_irls_summands_nlls_shared_state, round_idx = aggregation_step(
                aggregation_method=self.make_global_irls_update,
                train_data_nodes=train_data_nodes,
                aggregation_node=aggregation_node,
                input_shared_states=local_irls_summands_nlls_shared_states,
                round_idx=round_idx,
                description="Update the log fold changes and nlls in IRLS.",
                clean_models=clean_models,
            )

        return local_states, global_irls_summands_nlls_shared_state, round_idx
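Both :math:`X^T W_k X` and :math:`X^T W_k z_k` are sums over samples. Each center can therefore compute partial sums on its own data, and the aggregation node only needs to add them before solving. The negative log-likelihood decomposes the same way.

The following NumPy sketch illustrates this with three hypothetical in-memory centers: plain arrays stand in for TrainDataNodes, and there are no local_step or aggregation_step calls. It checks that one aggregated update matches the update computed on the pooled data. `local_summands` and `global_update` are illustrative names, not the library's methods.

```python
import numpy as np


def local_summands(y, X, gamma, beta, alpha, min_mu=0.5):
    """Per-center contributions for one gene: X^T W X, X^T W z and the local NLL."""
    mu = np.maximum(gamma * np.exp(X @ beta), min_mu)
    w = mu / (1.0 + mu * alpha)
    z = np.log(mu / gamma) + (y - mu) / mu
    hat = X.T @ (w[:, None] * X)
    feat = X.T @ (w * z)
    r = 1.0 / alpha
    nll = -np.sum(y * np.log(mu) - (y + r) * np.log(mu + r))  # up to a constant
    return hat, feat, nll


def global_update(summands, ridge=1e-6):
    """Aggregator: sum the local summands, then solve the ridge-stabilized system."""
    hat = sum(s[0] for s in summands)
    feat = sum(s[1] for s in summands)
    nll = sum(s[2] for s in summands)
    return np.linalg.solve(hat + ridge * np.eye(hat.shape[0]), feat), nll


rng = np.random.default_rng(1)
centers = []
for n in (30, 50, 20):  # three hypothetical centers of different sizes
    X = np.column_stack([np.ones(n), rng.integers(0, 2, n)]).astype(float)
    gamma = rng.uniform(0.5, 1.5, n)
    y = rng.poisson(gamma * np.exp(X @ np.array([2.0, 0.7]))).astype(float)
    centers.append((y, X, gamma))

beta, alpha = np.array([1.5, 0.0]), 0.1
fed_beta, fed_nll = global_update(
    [local_summands(y, X, g, beta, alpha) for y, X, g in centers]
)

# Pooled (centralized) computation on the concatenated data, for comparison.
y_all, X_all, g_all = (np.concatenate(parts) for parts in zip(*centers))
pooled_beta, pooled_nll = global_update([local_summands(y_all, X_all, g_all, beta, alpha)])
print(fed_beta, pooled_beta)
```

In this scheme only per-gene p x p matrices, p-vectors and scalar likelihoods leave each center, never the individual counts.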

run_fed_irls(train_data_nodes, aggregation_node, local_states, input_shared_state, round_idx, clean_models=True, refit_mode=False)

Run the IRLS algorithm.

Parameters:

- train_data_nodes (list, required): List of TrainDataNode.
- aggregation_node (AggregationNode, required): The aggregation node.
- local_states (dict, required): Local states. Required to propagate intermediate results.
- input_shared_state (dict, required): Shared state with the following keys:
    - beta (ndarray): The current beta, of shape (n_non_zero_genes, n_params).
    - irls_diverged_mask (ndarray): A boolean mask indicating if fed avg should be used for a given gene (shape: (n_non_zero_genes,)).
    - irls_mask (ndarray): A boolean mask indicating if IRLS should be used for a given gene (shape: (n_non_zero_genes,)).
    - global_nll (ndarray): The global_nll of the current beta from the previous beta, of shape (n_non_zero_genes,).
    - round_number_irls (int): The current round number of the IRLS algorithm.
- round_idx (int, required): The current round.
- clean_models (bool, default: True): If True, the models are cleaned.
- refit_mode (bool, default: False): Whether to run on refit_adata instead of local_adata.

Returns:

- local_states (dict): Local states. Required to propagate intermediate results.
- global_irls_summands_nlls_shared_state (dict): Shared states containing the final IRLS results, with the following keys:
    - beta (ndarray): The current beta, of shape (n_non_zero_genes, n_params).
    - irls_diverged_mask (ndarray): A boolean mask indicating if fed avg should be used for a given gene (shape: (n_non_zero_genes,)).
    - irls_mask (ndarray): A boolean mask indicating if IRLS should be used for a given gene (shape: (n_non_zero_genes,)).
    - global_nll (ndarray): The global_nll of the current beta from the previous beta, of shape (n_non_zero_genes,).
    - round_number_irls (int): The current round number of the IRLS algorithm.
- round_idx (int): The updated round index.


substeps

Module to implement the substeps for the fitting of log fold changes.

This module contains all these substeps as mixin classes.

AggMakeIRLSUpdate

Mixin class to aggregate IRLS summands.

Please refer to the method make_local_irls_summands_and_nlls for more details.

Attributes:

- num_jobs (int): The number of cpus to use.
- joblib_verbosity (int): The verbosity of the joblib backend.
- joblib_backend (str): The backend to use for the joblib parallelization.
- irls_batch_size (int): The batch size to use for the IRLS algorithm.
- max_beta (float): The maximum value for the beta parameter.
- beta_tol (float): The tolerance for the beta parameter.
- irls_num_iter (int): The number of iterations for the IRLS algorithm.

Methods:

- make_global_irls_update: A remote method. Aggregates the local quantities to create the global IRLS update. It also updates the masks indicating which genes have diverged or converged according to the deviance.

Source code in fedpydeseq2/core/fed_algorithms/fed_irls/substeps.py
class AggMakeIRLSUpdate:
    """Mixin class to aggregate IRLS summands.

    Please refer to the method make_local_irls_summands_and_nlls for more details.

    Attributes
    ----------
    num_jobs : int
        The number of cpus to use.
    joblib_verbosity : int
        The verbosity of the joblib backend.
    joblib_backend : str
        The backend to use for the joblib parallelization.
    irls_batch_size : int
        The batch size to use for the IRLS algorithm.
    max_beta : float
        The maximum value for the beta parameter.
    beta_tol : float
        The tolerance for the beta parameter.
    irls_num_iter : int
        The number of iterations for the IRLS algorithm.

    Methods
    -------
    make_global_irls_update
        A remote method. Aggregates the local quantities to create
        the global IRLS update. It also updates the masks indicating which genes
        have diverged or converged according to the deviance.

    """

    num_jobs: int
    joblib_verbosity: int
    joblib_backend: str
    irls_batch_size: int
    max_beta: float
    beta_tol: float
    irls_num_iter: int

    @remote
    @log_remote
    def make_global_irls_update(self, shared_states: list[dict]) -> dict[str, Any]:
        """Aggregate the local IRLS summands and compute the global IRLS update.

        The role of this function is twofold.

        1) It computes the global_nll and updates the masks according to the deviance,
        for the beta values that have been computed in the previous round.

        2) It aggregates the local hat matrix and features to solve the linear system
        and get the new beta values.

        Parameters
        ----------
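        shared_states: list[dict]
            A list of dictionaries containing the following
            keys:
            - beta: ndarray
                The current beta, of shape (n_non_zero_genes, n_params).
            - local_hat_matrix: ndarray
                The local hat matrix, of shape (n_irls_genes, n_params, n_params).
                n_irls_genes is the number of genes that are still active (genes
                for which irls_mask is True).
            - local_features: ndarray
                The local features, of shape (n_irls_genes, n_params).
            - local_nll: ndarray
                The local negative log-likelihood of the current beta on the
                active genes, of shape (n_irls_genes,).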
            - irls_diverged_mask: ndarray
                A boolean mask indicating if fed avg should be used for a given gene
                (shape: (n_non_zero_genes,)).
            - irls_mask: ndarray
                A boolean mask indicating if IRLS should be used for a given gene
                (shape: (n_non_zero_genes,)).
            - global_nll: ndarray
                The global_nll of the current beta from the previous beta, of shape\
                (n_non_zero_genes,).
            - round_number_irls: int
                The current round number of the IRLS algorithm.

        Returns
        -------
        dict[str, Any]
            A dictionary containing all the necessary info to run IRLS.
            It contains the following fields:
            - beta: ndarray
                The log fold changes, of shape (n_non_zero_genes, n_params).
            - irls_diverged_mask: ndarray
                A boolean mask indicating if fed avg should be used for a given gene
                (shape: (n_non_zero_genes,)).
            - irls_mask: ndarray
                A boolean mask indicating if IRLS should be used for a given gene
                (shape: (n_non_zero_genes,)).
            - global_nll: ndarray
                The global_nll of the current beta from the previous beta, of shape\
                (n_non_zero_genes,).
            - round_number_irls: int
                The current round number of the IRLS algorithm.


        """
        # Load main params from the first state
        beta = shared_states[0]["beta"]
        irls_mask = shared_states[0]["irls_mask"]
        irls_diverged_mask = shared_states[0]["irls_diverged_mask"]
        global_nll = shared_states[0]["global_nll"]
        round_number_irls = shared_states[0]["round_number_irls"]

        # ---- Step 0: Aggregate the local hat matrix, features and global_nll ---- #

        global_hat_matrix = sum([state["local_hat_matrix"] for state in shared_states])
        global_features = sum([state["local_features"] for state in shared_states])
        global_nll_on_irls_mask = sum([state["local_nll"] for state in shared_states])

        # ---- Step 1: update global_nll and masks ---- #

        # The first round needs to be handled separately
        if round_number_irls == 0:
            # In that case, the irls_mask consists of all True values
            # We only need to set the initial global_nll
            global_nll = global_nll_on_irls_mask

        else:
            old_global_nll = global_nll.copy()
            old_irls_mask = irls_mask.copy()

            global_nll[irls_mask] = global_nll_on_irls_mask

            # Set the new masks with the dev ratio and beta values
            deviance_ratio = np.abs(2 * global_nll - 2 * old_global_nll) / (
                np.abs(2 * global_nll) + 0.1
            )
            irls_diverged_mask = irls_diverged_mask | (
                np.abs(beta) > self.max_beta
            ).any(axis=1)

            irls_mask = irls_mask & (deviance_ratio > self.beta_tol)
            irls_mask = irls_mask & ~irls_diverged_mask
            new_mask_in_old_mask = (irls_mask & old_irls_mask)[old_irls_mask]
            global_hat_matrix = global_hat_matrix[new_mask_in_old_mask]
            global_features = global_features[new_mask_in_old_mask]

        if round_number_irls == self.irls_num_iter:
            # In this case, we must prepare the switch to fed prox newton
            return {
                "beta": beta,
                "irls_diverged_mask": irls_diverged_mask,
                "irls_mask": irls_mask,
                "global_nll": global_nll,
                "round_number_irls": round_number_irls,
            }

        # ---- Step 2: Solve the system to compute beta ---- #

        ridge_factor = np.diag(np.repeat(1e-6, global_hat_matrix.shape[1]))
        with parallel_backend(self.joblib_backend):
            res = Parallel(n_jobs=self.num_jobs, verbose=self.joblib_verbosity)(
                delayed(np.linalg.solve)(
                    global_hat_matrix[i : i + self.irls_batch_size] + ridge_factor,
                    global_features[i : i + self.irls_batch_size],
                )
                for i in range(0, len(global_hat_matrix), self.irls_batch_size)
            )
        if len(res) > 0:
            beta_hat = np.concatenate(res)
        else:
            beta_hat = np.zeros((0, global_hat_matrix.shape[1]))

        # TODO: it would be cleaner to pass an update, which is None at the first
        #  round. That way we do not update beta in a different step than its
        #  evaluation.

        # Update the beta
        beta[irls_mask] = beta_hat

        round_number_irls = round_number_irls + 1

        return {
            "beta": beta,
            "irls_diverged_mask": irls_diverged_mask,
            "irls_mask": irls_mask,
            "global_nll": global_nll,
            "round_number_irls": round_number_irls,
        }
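Step 1 of make_global_irls_update can be exercised on its own. The sketch below reimplements three pieces of it on toy values:

- the deviance-ratio test;
- the divergence check on beta;
- the index new_mask_in_old_mask.

The toy setup has four genes, and the last one is already inactive. The values `max_beta=30` and `beta_tol=1e-8` are assumptions chosen for the example, not necessarily the library's defaults. `update_masks` is a hypothetical helper name.

```python
import numpy as np


def update_masks(beta, global_nll, new_nll_on_mask, irls_mask, diverged_mask,
                 max_beta=30.0, beta_tol=1e-8):
    """Mirror of Step 1 of make_global_irls_update (sketch on toy inputs)."""
    old_nll = global_nll.copy()
    old_mask = irls_mask.copy()
    # Copied to avoid mutating the caller's array (the original updates in place).
    global_nll = global_nll.copy()
    global_nll[irls_mask] = new_nll_on_mask  # only active genes were re-evaluated
    deviance_ratio = np.abs(2 * global_nll - 2 * old_nll) / (
        np.abs(2 * global_nll) + 0.1
    )
    diverged_mask = diverged_mask | (np.abs(beta) > max_beta).any(axis=1)
    irls_mask = irls_mask & (deviance_ratio > beta_tol) & ~diverged_mask
    # Position, among the previously active genes, of those that are still active.
    new_mask_in_old_mask = (irls_mask & old_mask)[old_mask]
    return global_nll, irls_mask, diverged_mask, new_mask_in_old_mask


beta = np.array([[1.0, 0.5], [2.0, 40.0], [0.3, 0.1], [5.0, 1.0]])
global_nll = np.array([100.0, 80.0, 50.0, 70.0])
irls_mask = np.array([True, True, True, False])  # gene 3 already stopped
diverged = np.zeros(4, dtype=bool)
new_nll = np.array([90.0, 79.0, 50.0])  # gene 2: no change, so it has converged
nll, mask, div, sub = update_masks(beta, global_nll, new_nll, irls_mask, diverged)
print(mask, div, sub)
```

With these inputs the three active genes end up in different states:

- Gene 0 keeps iterating.
- Gene 1 is flagged as diverged because one coefficient exceeds max_beta.
- Gene 2 stops because its negative log-likelihood did not change.

new_mask_in_old_mask is then [True, False, False]. It selects the single remaining row of the aggregated hat matrix and features.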

make_global_irls_update(shared_states)

Aggregate the local IRLS summands and compute the global IRLS update.

The role of this function is twofold.

1) It computes the global_nll and updates the masks according to the deviance, for the beta values that have been computed in the previous round.

2) It aggregates the local hat matrix and features to solve the linear system and get the new beta values.
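The linear solve of step 2 handles one small p x p system per gene, in batches dispatched through joblib. The sketch below performs the same batched, ridge-stabilized solve with plain NumPy and no joblib. `batched_ridge_solve`, `batch_size=2` and the random positive-definite matrices are illustrative.

The right-hand side is passed with an explicit trailing axis (`[..., None]`). Since NumPy 2.0, `np.linalg.solve` treats a `b` that is not 1-D as a stack of matrices rather than a stack of vectors, so this form behaves the same across NumPy versions.

```python
import numpy as np


def batched_ridge_solve(hat, features, batch_size, ridge=1e-6):
    """Solve (hat[g] + ridge * I) beta[g] = features[g] for every gene g.

    hat: (n_genes, p, p), features: (n_genes, p); returns (n_genes, p).
    """
    p = hat.shape[1]
    ridge_factor = np.diag(np.repeat(ridge, p))  # broadcast over the gene axis
    res = [
        np.linalg.solve(
            hat[i : i + batch_size] + ridge_factor,
            features[i : i + batch_size][..., None],  # (batch, p, 1): stack of vectors
        )[..., 0]
        for i in range(0, len(hat), batch_size)
    ]
    # Same empty-input fallback as in make_global_irls_update.
    return np.concatenate(res) if res else np.zeros((0, p))


rng = np.random.default_rng(0)
A = rng.normal(size=(5, 3, 3))
hat = A @ np.transpose(A, (0, 2, 1)) + 3 * np.eye(3)  # SPD, like X^T W X
features = rng.normal(size=(5, 3))
beta_hat = batched_ridge_solve(hat, features, batch_size=2)
```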

Parameters:

- shared_states (list[dict], required): A list of dictionaries containing the following keys:
    - beta (ndarray): The current beta, of shape (n_non_zero_genes, n_params).
    - local_hat_matrix (ndarray): The local hat matrix, of shape (n_irls_genes, n_params, n_params). n_irls_genes is the number of genes that are still active (genes for which irls_mask is True).
    - local_features (ndarray): The local features, of shape (n_irls_genes, n_params).
    - local_nll (ndarray): The local negative log-likelihood of the current beta on the active genes, of shape (n_irls_genes,).
    - irls_diverged_mask (ndarray): A boolean mask indicating if fed avg should be used for a given gene (shape: (n_non_zero_genes,)).
    - irls_mask (ndarray): A boolean mask indicating if IRLS should be used for a given gene (shape: (n_non_zero_genes,)).
    - global_nll (ndarray): The global_nll of the current beta from the previous beta, of shape (n_non_zero_genes,).
    - round_number_irls (int): The current round number of the IRLS algorithm.

Returns:

- dict[str, Any]: A dictionary containing all the necessary info to run IRLS, with the following fields:
    - beta (ndarray): The log fold changes, of shape (n_non_zero_genes, n_params).
    - irls_diverged_mask (ndarray): A boolean mask indicating if fed avg should be used for a given gene (shape: (n_non_zero_genes,)).
    - irls_mask (ndarray): A boolean mask indicating if IRLS should be used for a given gene (shape: (n_non_zero_genes,)).
    - global_nll (ndarray): The global_nll of the current beta from the previous beta, of shape (n_non_zero_genes,).
    - round_number_irls (int): The current round number of the IRLS algorithm.


LocMakeIRLSSummands

Mixin to make the summands for the IRLS algorithm.

Attributes:

- local_adata (AnnData): The local AnnData object.
- num_jobs (int): The number of cpus to use.
- joblib_verbosity (int): The verbosity of the joblib backend.
- joblib_backend (str): The backend to use for the joblib parallelization.
- irls_batch_size (int): The batch size to use for the IRLS algorithm.
- min_mu (float): The minimum value for the mu parameter.
- irls_num_iter (int): The number of iterations for the IRLS algorithm.

Methods:

- make_local_irls_summands_and_nlls: A remote_data method. Makes the summands for the IRLS algorithm. It also passes on the necessary global quantities.
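On each center, make_local_irls_summands_and_nlls produces per-gene local_hat_matrix, local_features and local_nll arrays for the active genes, which the aggregator then sums. The vectorized NumPy sketch below shows how such arrays could be formed. It rests on a few assumptions:

- The function name `local_irls_summands` is hypothetical.
- Counts use the (n_samples, n_genes) layout of an AnnData X matrix.
- `min_mu=0.5` is an assumed default.

The real method may batch the computation with joblib and handle masks and refit_mode differently.

```python
import numpy as np


def local_irls_summands(counts, X, size_factors, beta, alpha, irls_mask, min_mu=0.5):
    """Per-gene local summands for the genes where irls_mask is True.

    counts: (n_samples, n_genes), X: (n_samples, p), size_factors: (n_samples,),
    beta: (n_genes, p), alpha: (n_genes,), irls_mask: (n_genes,) bool.
    Returns hat (n_active, p, p), features (n_active, p), nll (n_active,).
    """
    y = counts[:, irls_mask]
    b = beta[irls_mask]
    a = alpha[irls_mask]
    mu = np.maximum(size_factors[:, None] * np.exp(X @ b.T), min_mu)  # clipped mean
    w = mu / (1.0 + mu * a)                                    # diagonal of W per gene
    z = np.log(mu / size_factors[:, None]) + (y - mu) / mu     # feature vector per gene
    hat = np.einsum("ni,ng,nj->gij", X, w, X)                  # X^T W X for each gene
    features = np.einsum("ni,ng->gi", X, w * z)                # X^T W z for each gene
    r = 1.0 / a
    nll = -np.sum(y * np.log(mu) - (y + r) * np.log(mu + r), axis=0)  # up to a constant
    return hat, features, nll
```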

Source code in fedpydeseq2/core/fed_algorithms/fed_irls/substeps.py
class LocMakeIRLSSummands:
    """Mixin to make the summands for the IRLS algorithm.

    Attributes
    ----------
    local_adata : AnnData
        The local AnnData object.
    num_jobs : int
        The number of CPUs to use.
    joblib_verbosity : int
        The verbosity of the joblib backend.
    joblib_backend : str
        The backend to use for the joblib parallelization.
    irls_batch_size : int
        The batch size to use for the IRLS algorithm.
    min_mu : float
        The minimum value for the mu parameter.
    irls_num_iter : int
        The number of iterations for the IRLS algorithm.

    Methods
    -------
    make_local_irls_summands_and_nlls
        A remote_data method. Makes the summands for the IRLS algorithm.
        It also passes on the necessary global quantities.

    """

    local_adata: AnnData
    refit_adata: AnnData
    num_jobs: int
    joblib_verbosity: int
    joblib_backend: str
    irls_batch_size: int
    min_mu: float
    irls_num_iter: int

    @remote_data
    @log_remote_data
    @reconstruct_adatas
    def make_local_irls_summands_and_nlls(
        self,
        data_from_opener: AnnData,
        shared_state: dict[str, Any],
        refit_mode: bool = False,
    ):
        """Make the summands for the IRLS algorithm.

        This function performs two main operations:

        1) It computes the summands for the beta update.
        2) It computes the local quantities needed to compute the global_nll
        of the current beta.


        Parameters
        ----------
        data_from_opener : AnnData
            Not used.

        shared_state : dict
            A dictionary containing the following
            keys:
            - beta: ndarray
                The current beta, of shape (n_non_zero_genes, n_params).
            - irls_diverged_mask: ndarray
                A boolean mask indicating if fed avg should be used for a given gene
                (shape: (n_non_zero_genes,)).
            - irls_mask: ndarray
                A boolean mask indicating if IRLS should be used for a given gene
                (shape: (n_non_zero_genes,)).
            - global_nll: ndarray
                The global_nll of the current beta from the previous beta, of shape\
                (n_non_zero_genes,).
            - round_number_irls: int
                The current round number of the IRLS algorithm.

        refit_mode : bool
            Whether to run on `refit_adata`s instead of `local_adata`s.
            (default: False).

        Returns
        -------
        dict
            The state to share to the server.
            It contains the following fields:
            - beta: ndarray
                The current beta, of shape (n_non_zero_genes, n_params).
            - local_nll: ndarray
                The local nll of the current beta, of shape (n_irls_genes,).
            - local_hat_matrix: ndarray
                The local hat matrix, of shape (n_irls_genes, n_params, n_params).
                n_irls_genes is the number of genes that are still active (non-zero
                genes for which irls_mask is True).
            - local_features: ndarray
                The local features, of shape (n_irls_genes, n_params).
            - irls_diverged_mask: ndarray
                A boolean mask indicating if fed avg should be used for a given gene
                (shape: (n_non_zero_genes,)).
            - irls_mask: ndarray
                A boolean mask indicating if IRLS should be used for a given gene
                (shape: (n_non_zero_genes,)).
            - global_nll: ndarray
                The global_nll of the current beta of shape
                (n_non_zero_genes,).
                This parameter is simply passed to the next shared state
            - round_number_irls: int
                The current round number of the IRLS algorithm.
                This round number is not updated here.

        """
        if refit_mode:
            adata = self.refit_adata
        else:
            adata = self.local_adata

        # Put all elements in the shared state in readable variables
        beta = shared_state["beta"]
        irls_mask = shared_state["irls_mask"]
        irls_diverged_mask = shared_state["irls_diverged_mask"]
        global_nll = shared_state["global_nll"]
        round_number_irls = shared_state["round_number_irls"]

        # Get the quantities stored in the adata
        disp_param_name = adata.uns["_irls_disp_param_name"]

        # If this is the first round, save the beta init in a field of the local adata
        if round_number_irls == 0:
            adata.uns["_irls_beta_init"] = beta.copy()

        (
            irls_gene_names,
            design_matrix,
            size_factors,
            counts,
            dispersions,
            beta_genes,
        ) = get_lfc_utils_from_gene_mask_adata(
            adata, irls_mask, beta=beta, disp_param_name=disp_param_name
        )

        # ---- Compute the summands for the beta update and the local nll ---- #

        with parallel_backend(self.joblib_backend):
            res = Parallel(n_jobs=self.num_jobs, verbose=self.joblib_verbosity)(
                delayed(make_irls_update_summands_and_nll_batch)(
                    design_matrix,
                    size_factors,
                    beta_genes[i : i + self.irls_batch_size],
                    dispersions[i : i + self.irls_batch_size],
                    counts[:, i : i + self.irls_batch_size],
                    self.min_mu,
                )
                for i in range(0, len(beta_genes), self.irls_batch_size)
            )

        if len(res) == 0:
            H = np.zeros((0, beta.shape[1], beta.shape[1]))
            y = np.zeros((0, beta.shape[1]))
            local_nll = np.zeros(0)
        else:
            H = np.concatenate([r[0] for r in res])
            y = np.concatenate([r[1] for r in res])
            local_nll = np.concatenate([r[2] for r in res])

        # Create the shared state
        return {
            "beta": beta,
            "local_nll": local_nll,
            "local_hat_matrix": H,
            "local_features": y,
            "irls_gene_names": irls_gene_names,
            "irls_diverged_mask": irls_diverged_mask,
            "irls_mask": irls_mask,
            "global_nll": global_nll,
            "round_number_irls": round_number_irls,
        }

make_local_irls_summands_and_nlls(data_from_opener, shared_state, refit_mode=False)

Make the summands for the IRLS algorithm.

This function performs two main operations:

1) It computes the summands for the beta update.
2) It computes the local quantities needed to compute the global_nll of the current beta.

Parameters:

data_from_opener (AnnData, required): Not used.

shared_state (dict, required): A dictionary containing the following keys:
- beta: ndarray. The current beta, of shape (n_non_zero_genes, n_params).
- irls_diverged_mask: ndarray. A boolean mask indicating if fed avg should be used for a given gene (shape: (n_non_zero_genes,)).
- irls_mask: ndarray. A boolean mask indicating if IRLS should be used for a given gene (shape: (n_non_zero_genes,)).
- global_nll: ndarray. The global_nll of the current beta from the previous beta, of shape (n_non_zero_genes,).
- round_number_irls: int. The current round number of the IRLS algorithm.

refit_mode (bool, default: False): Whether to run on refit_adatas instead of local_adatas.

Returns:

dict: The state to share to the server. It contains the following fields:
- beta: ndarray. The current beta, of shape (n_non_zero_genes, n_params).
- local_nll: ndarray. The local nll of the current beta, of shape (n_irls_genes,).
- local_hat_matrix: ndarray. The local hat matrix, of shape (n_irls_genes, n_params, n_params), where n_irls_genes is the number of genes that are still active (non-zero genes for which irls_mask is True).
- local_features: ndarray. The local features, of shape (n_irls_genes, n_params).
- irls_diverged_mask: ndarray. A boolean mask indicating if fed avg should be used for a given gene (shape: (n_non_zero_genes,)).
- irls_mask: ndarray. A boolean mask indicating if IRLS should be used for a given gene (shape: (n_non_zero_genes,)).
- global_nll: ndarray. The global_nll of the current beta, of shape (n_non_zero_genes,). This parameter is simply passed on to the next shared state.
- round_number_irls: int. The current round number of the IRLS algorithm. This round number is not updated here.
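The local hat matrix X_c^T W X_c and the local features X_c^T W z_c returned by each center are sums over that center's samples, so adding them across centers gives exactly the pooled quantities. This is what lets a server solve for beta without seeing any counts. The sketch below is a toy, single-gene illustration under assumptions: `local_summands` is a hypothetical simplified helper (it keeps the min_mu trimming but not the overflow handling of the real batch function), the two centers and their data are simulated, and it assumes the server forms the global arrays by plain summation, which this section does not show.

```python
import numpy as np


def local_summands(X, size_factors, counts, beta, dispersion, min_mu=0.5):
    """IRLS summands for one gene on one center (hypothetical helper).

    Returns H = X^T W X, shape (n_params, n_params), and y = X^T W z, shape (n_params,).
    """
    mu = np.maximum(size_factors * np.exp(X @ beta), min_mu)
    w = mu / (1.0 + mu * dispersion)  # diagonal of W
    z = np.log(mu / size_factors) + (counts - mu) / mu  # working response
    H = X.T @ (w[:, None] * X)
    y = X.T @ (w * z)
    return H, y


rng = np.random.default_rng(1)
true_beta = np.array([3.0, 1.0])
centers = []  # two hypothetical centers: intercept + binary condition design
for n_obs in (30, 45):
    X = np.column_stack([np.ones(n_obs), rng.integers(0, 2, n_obs)])
    sf = rng.uniform(0.5, 2.0, n_obs)
    counts = rng.poisson(sf * np.exp(X @ true_beta)).astype(float)
    centers.append((X, sf, counts))

alpha = 0.1  # fixed dispersion
# Start from intercept log(sum of counts / sum of size factors), slope 0;
# both sums can themselves be aggregated across centers.
total_counts = sum(c.sum() for _, _, c in centers)
total_sf = sum(s.sum() for _, s, _ in centers)
beta = np.array([np.log(total_counts / total_sf), 0.0])

for _ in range(20):  # fixed number of IRLS rounds, no convergence test
    summands = [local_summands(X, s, c, beta, alpha) for X, s, c in centers]
    H_global = sum(h for h, _ in summands)  # server-side aggregation
    y_global = sum(v for _, v in summands)
    beta = np.linalg.solve(H_global + 1e-6 * np.eye(2), y_global)

print(beta)  # expected to land near the simulated [3.0, 1.0]
```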


utils

Module to implement the utilities of the IRLS algorithm.

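Most of these functions have the _batch suffix, which means that they are vectorized to work over a batch of genes at once. The substeps of this module slice the genes into batches of irls_batch_size and dispatch them in parallel with joblib under parallel_backend.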
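A _batch function takes a slice of genes and returns the matching slice of results, so the caller can cut the genes into batches, evaluate them in parallel and concatenate the pieces in order. The sketch below shows that pattern under assumptions: `mu_batch` is a hypothetical helper that computes the trimmed means for a batch of genes, and the standard library's `ThreadPoolExecutor` stands in for the joblib `Parallel`/`delayed` dispatch used in the substeps.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def mu_batch(design_matrix, size_factors, beta, min_mu):
    """Trimmed means for a batch of genes (hypothetical helper).

    design_matrix: (n_obs, n_params), size_factors: (n_obs,),
    beta: (batch_size, n_params). Returns mu of shape (n_obs, batch_size).
    """
    mu = size_factors[:, None] * np.exp(design_matrix @ beta.T)
    return np.maximum(mu, min_mu)


rng = np.random.default_rng(0)
X = np.column_stack([np.ones(6), [0, 0, 0, 1, 1, 1]])  # 6 samples, 2 params
sf = rng.uniform(0.5, 2.0, size=6)
beta = rng.normal(size=(7, 2))  # 7 genes
batch_size = 3

starts = range(0, len(beta), batch_size)
with ThreadPoolExecutor(max_workers=2) as pool:
    # Executor.map returns results in input order, so concatenating
    # the pieces restores the original gene order.
    parts = list(pool.map(lambda s: mu_batch(X, sf, beta[s : s + batch_size], 0.5), starts))
mu = np.concatenate(parts, axis=1) if parts else np.zeros((len(sf), 0))
```

The empty-list guard plays the same role as the `len(res) == 0` checks in the substeps: when no gene is active, the result keeps a well-defined shape.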

make_irls_update_summands_and_nll_batch(design_matrix, size_factors, beta, dispersions, counts, min_mu)

Make the summands for the IRLS algorithm for a given set of genes.

Parameters:

design_matrix (ndarray, required): The design matrix, of shape (n_obs, n_params).
size_factors (ndarray, required): The size factors, of shape (n_obs,).
beta (ndarray, required): The log fold change matrix, of shape (batch_size, n_params).
dispersions (ndarray, required): The dispersions, of shape (batch_size,).
counts (ndarray, required): The counts, of shape (n_obs, batch_size).
min_mu (float, required): Lower bound on estimated means, to ensure numerical stability.

Returns:

H (ndarray): The H matrix, of shape (batch_size, n_params, n_params).
y (ndarray): The y vector, of shape (batch_size, n_params).
nll (ndarray): The negative binomial negative log-likelihood, of shape (batch_size,).
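For one gene, the H and y arrays can be summarized in the notation of the IRLS description at the top of this page (counts y_i, size factors γ_i, dispersion α), with μ_min standing for min_mu. Here \texttt{H} and \texttt{y} denote the returned arrays, not the counts. The last line sketches the Step 2 update with its 1e-6 ridge, assuming the server sums the per-center arrays over centers c. These formulas do not cover the overflow branch of the code (linear predictor above log(10^100)), and nll is computed by grid_nb_nll, which is not shown here.

```latex
\begin{aligned}
\mu_{k,i} &= \max\left(\gamma_i \exp(X_i \cdot \beta_k),\ \mu_{\min}\right) \\
(W_k)_{ii} &= \frac{\mu_{k,i}}{1 + \mu_{k,i}\,\alpha} \\
(z_k)_i &= \log\frac{\mu_{k,i}}{\gamma_i} + \frac{y_i - \mu_{k,i}}{\mu_{k,i}} \\
\texttt{H} &= X^\top W_k X, \qquad \texttt{y} = X^\top W_k z_k \\
\beta_{k+1} &= \Big(\sum_c \texttt{H}_c + \varepsilon I\Big)^{-1} \sum_c \texttt{y}_c, \qquad \varepsilon = 10^{-6}
\end{aligned}
```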

Source code in fedpydeseq2/core/fed_algorithms/fed_irls/utils.py
def make_irls_update_summands_and_nll_batch(
    design_matrix: np.ndarray,
    size_factors: np.ndarray,
    beta: np.ndarray,
    dispersions: np.ndarray,
    counts: np.ndarray,
    min_mu: float,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Make the summands for the IRLS algorithm for a given set of genes.

    Parameters
    ----------
    design_matrix : ndarray
        The design matrix, of shape (n_obs, n_params).
    size_factors : ndarray
        The size factors, of shape (n_obs,).
    beta : ndarray
        The log fold change matrix, of shape (batch_size, n_params).
    dispersions : ndarray
        The dispersions, of shape (batch_size,).
    counts : ndarray
        The counts, of shape (n_obs, batch_size).
    min_mu : float
        Lower bound on estimated means, to ensure numerical stability.

    Returns
    -------
    H : ndarray
        The H matrix, of shape (batch_size, n_params, n_params).
    y : ndarray
        The y vector, of shape (batch_size, n_params).
    nll : ndarray
        The negative binomial negative log-likelihood, of shape (batch_size,).
    """
    max_limit = np.log(1e100)
    design_matrix_time_beta_T = design_matrix @ beta.T
    mask_nan = design_matrix_time_beta_T > max_limit

    # In order to avoid overflow and np.inf, we do not exponentiate the large values
    # of design_matrix_time_beta_T (their exponential is left at 0.), we carry out
    # the computation normally, and we then overwrite the final quantities with their
    # limiting values for the inputs where exp_design_matrix_time_beta_T would have
    # taken values >> 1.
    exp_design_matrix_time_beta_T = np.zeros(
        design_matrix_time_beta_T.shape, dtype=design_matrix_time_beta_T.dtype
    )
    exp_design_matrix_time_beta_T[~mask_nan] = np.exp(
        design_matrix_time_beta_T[~mask_nan]
    )
    mu = size_factors[:, None] * exp_design_matrix_time_beta_T

    mu = np.maximum(mu, min_mu)

    W = mu / (1.0 + mu * dispersions[None, :])

    dispersions_broadcast = np.broadcast_to(
        dispersions, (mu.shape[0], dispersions.shape[0])
    )
    W[mask_nan] = 1.0 / dispersions_broadcast[mask_nan]

    z = np.log(mu / size_factors[:, None]) + (counts - mu) / mu
    z[mask_nan] = design_matrix_time_beta_T[mask_nan] - 1.0

    H = (design_matrix.T[:, :, None] * W).transpose(2, 0, 1) @ design_matrix[None, :, :]
    y = (design_matrix.T @ (W * z)).T

    mu[mask_nan] = np.inf
    nll = grid_nb_nll(counts, mu, dispersions, mask_nan)

    return H, y, nll
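The masking at the top of the function above guards against overflow in `np.exp`. Entries whose linear predictor exceeds log(1e100) are never exponentiated. W and z are then set to their limits as mu grows without bound: mu / (1 + mu * alpha) tends to 1 / alpha, and log(mu / gamma) + (counts - mu) / mu tends to X_i · beta - 1. The following minimal single-gene sketch compares that trick with the naive formula; `weights_and_working_response` and the toy values are hypothetical.

```python
import numpy as np


def weights_and_working_response(xb, size_factors, counts, dispersion, min_mu=0.5):
    """W diagonal and working response z for one gene, overflow-safe (hypothetical helper).

    xb: linear predictor X @ beta, shape (n_obs,).
    """
    max_limit = np.log(1e100)
    big = xb > max_limit
    exp_xb = np.zeros_like(xb)
    exp_xb[~big] = np.exp(xb[~big])  # only exponentiate safe entries
    mu = np.maximum(size_factors * exp_xb, min_mu)
    w = mu / (1.0 + mu * dispersion)
    z = np.log(mu / size_factors) + (counts - mu) / mu
    # Overwrite the masked entries with their limits as mu -> inf.
    w[big] = 1.0 / dispersion
    z[big] = xb[big] - 1.0
    return w, z


xb = np.array([1.0, 300.0, 800.0])  # exp(800) overflows float64
sf = np.ones(3)
counts = np.array([3.0, 5.0, 7.0])
alpha = 0.2

with np.errstate(over="ignore", invalid="ignore"):
    naive_mu = sf * np.exp(xb)
    naive_w = naive_mu / (1.0 + naive_mu * alpha)
print(naive_w)  # the last entry is nan (inf / inf)

w, z = weights_and_working_response(xb, sf, counts, alpha)
print(w, z)
```

The masked entries stay finite and agree with what the naive formula gives wherever it still works (the xb = 300 entry), which is why replacing them by their limits is a safe choice here.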