Core

aggregation

Aggregation functions.

Copy-pasted from the CancerLINQ repo.

aggregate_means(local_means, n_local_samples, filter_nan=False)

Aggregate local means.

Aggregate the local means into a global mean by using the local number of samples.

Parameters:

- local_means (list[Any], required): List of local means. Can be an array, a float, or a Series.
- n_local_samples (list[int], required): List of the number of samples used for each local mean.
- filter_nan (bool, default: False): Whether to filter NaN values in the local means.

Returns:

- Any: Aggregated mean, of the same type as the local means.

Source code in fedpydeseq2/core/utils/aggregation.py
def aggregate_means(
    local_means: list[Any], n_local_samples: list[int], filter_nan: bool = False
):
    """Aggregate local means.

    Aggregate the local means into a global mean by using the local number of samples.

    Parameters
    ----------
    local_means : list[Any]
        list of local means. Could be array, float, Series.
    n_local_samples : list[int]
        list of number of samples used for each local mean.
    filter_nan : bool, optional
        Filter NaN values in the local means, by default False.

    Returns
    -------
    Any
        Aggregated mean. Same type as the local means.
    """
    tot_samples = 0
    tot_mean = np.zeros_like(local_means[0])
    for mean, n_sample in zip(local_means, n_local_samples, strict=False):
        if filter_nan:
            mean = np.nan_to_num(mean, nan=0, copy=False)
        tot_mean += mean * n_sample
        tot_samples += n_sample

    return tot_mean / tot_samples
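
A minimal usage sketch with hypothetical values from two centers:

from fedpydeseq2.core.utils.aggregation import aggregate_means
import numpy as np

# Two centers report local means over three genes (hypothetical values).
local_means = [np.array([1.0, 2.0, 3.0]), np.array([2.0, 4.0, 6.0])]
n_local_samples = [10, 30]

# Weighted by sample counts: (10 * m1 + 30 * m2) / 40
global_mean = aggregate_means(local_means, n_local_samples)
# array([1.75, 3.5 , 5.25])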

compute_lfc_utils

get_lfc_utils_from_gene_mask_adata(adata, gene_mask, disp_param_name, beta=None, lfc_param_name=None)

Get the necessary data for LFC computations from the local adata and genes.

Parameters:

- adata (AnnData, required): The local AnnData object.
- gene_mask (ndarray, required): The mask of genes to use for the IRLS algorithm. This mask identifies the genes in the non_zero_gene_names. If None, all non-zero genes are used.
- disp_param_name (str, required): The name of the dispersion parameter in adata.varm.
- beta (ndarray or None, default: None): The log fold change values, of shape (n_non_zero_genes,).
- lfc_param_name (str or None, default: None): The name of the LFC parameter in adata.varm. Incompatible with beta.

Returns:

- gene_names (list[str]): The names of the genes to use for the IRLS algorithm.
- design_matrix (ndarray): The design matrix.
- size_factors (ndarray): The size factors.
- counts (ndarray): The count matrix from the local adata.
- dispersions (ndarray): The dispersions from the local adata.
- beta_on_mask (ndarray): The log fold change values on the mask.

Source code in fedpydeseq2/core/utils/compute_lfc_utils.py
def get_lfc_utils_from_gene_mask_adata(
    adata: ad.AnnData,
    gene_mask: np.ndarray | None,
    disp_param_name: str,
    beta: np.ndarray | None = None,
    lfc_param_name: str | None = None,
) -> tuple[list[str], np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Get the necessary data for LFC computations from the local adata and genes.

    Parameters
    ----------
    adata : ad.AnnData
        The local AnnData object.

    gene_mask : np.ndarray, optional
        The mask of genes to use for the IRLS algorithm.
        This mask identifies the genes in the non_zero_gene_names.
        If None, all non zero genes are used.

    disp_param_name : str
        The name of the dispersion parameter in the adata.varm.

    beta : Optional[np.ndarray]
        The log fold change values, of shape (n_non_zero_genes,).

    lfc_param_name: Optional[str]
        The name of the lfc parameter in the adata.varm.
        Is incompatible with beta.

    Returns
    -------
    gene_names : list[str]
        The names of the genes to use for the IRLS algorithm.
    design_matrix : np.ndarray
        The design matrix.
    size_factors : np.ndarray
        The size factors.
    counts : np.ndarray
        The count matrix from the local adata.
    dispersions : np.ndarray
        The dispersions from the local adata.
    beta_on_mask : np.ndarray
        The log fold change values on the mask.
    """
    # Check that one of beta or lfc_param_name is not None
    assert (beta is not None) ^ (
        lfc_param_name is not None
    ), "One of beta or lfc_param_name must be not None"

    # Get non zero genes
    non_zero_genes_names = adata.var_names[adata.varm["non_zero"]]

    # Get the irls genes
    if gene_mask is None:
        gene_names = non_zero_genes_names
    else:
        gene_names = non_zero_genes_names[gene_mask]

    # Get beta
    if lfc_param_name is not None:
        beta_on_mask = adata[:, gene_names].varm[lfc_param_name].to_numpy()
    elif gene_mask is not None:
        assert beta is not None  # for mypy
        beta_on_mask = beta[gene_mask]
    else:
        assert beta is not None  # for mypy
        beta_on_mask = beta.copy()

    design_matrix = adata.obsm["design_matrix"].values
    size_factors = adata.obsm["size_factors"]
    counts = adata[:, gene_names].X
    dispersions = adata[:, gene_names].varm[disp_param_name]

    return gene_names, design_matrix, size_factors, counts, dispersions, beta_on_mask
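
A call sketch, assuming a local adata prepared by the rest of the pipeline; note that exactly one of beta and lfc_param_name must be provided:

from fedpydeseq2.core.utils.compute_lfc_utils import (
    get_lfc_utils_from_gene_mask_adata,
)

# Hypothetical call: read the LFC estimates from adata.varm["LFC"]
# and use all non-zero genes (gene_mask=None).
(
    gene_names,
    design_matrix,
    size_factors,
    counts,
    dispersions,
    beta_on_mask,
) = get_lfc_utils_from_gene_mask_adata(
    adata,
    gene_mask=None,
    disp_param_name="dispersions",
    lfc_param_name="LFC",
)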

design_matrix

build_design_matrix(metadata, design_factors='stage', levels=None, continuous_factors=None, ref_levels=None)

Build the design matrix for DEA.

Unless specified, the reference level is chosen alphabetically. Copied from PyDESeq2, with some modifications specific to fedomics to ensure that all centers have the same columns.

Parameters:

- metadata (DataFrame, required): DataFrame containing metadata information. Must be indexed by sample barcodes.
- design_factors (str or list, default: "stage"): Name of the columns of metadata to be used as design variables.
- levels (dict, default: None): An optional dictionary of lists of strings specifying the levels of each factor in the global design, e.g. {"condition": ["A", "B"]}.
- ref_levels (dict, default: None): An optional dictionary of the form {"factor": "ref_level"} specifying, for each factor, the reference (control) level against which we're testing, e.g. {"condition": "A"}. Factors that are left out will be assigned random reference levels.
- continuous_factors (list, default: None): An optional list of continuous (as opposed to categorical) factors, which should also be in design_factors. Any factor in design_factors but not in continuous_factors will be considered categorical.

Returns:

- DataFrame: A DataFrame with experiment design information (to split cohorts). Indexed by sample barcodes.

Source code in fedpydeseq2/core/utils/design_matrix.py
def build_design_matrix(
    metadata: pd.DataFrame,
    design_factors: str | list[str] = "stage",
    levels: dict[str, list[str]] | None = None,
    continuous_factors: list[str] | None = None,
    ref_levels: dict[str, str] | None = None,
) -> pd.DataFrame:
    """Build design_matrix matrix for DEA.

    Unless specified, the reference factor is chosen alphabetically.
    Copied from PyDESeq2, with some modifications specific to fedomics to ensure that
    all centers have the same columns

    Parameters
    ----------
    metadata : pandas.DataFrame
        DataFrame containing metadata information.
        Must be indexed by sample barcodes.

    design_factors : str or list
        Name of the columns of metadata to be used as design_matrix variables.
        (default: ``"condition"``).

    levels : dict, optional
        An optional dictionary of lists of strings specifying the levels of each factor
        in the global design, e.g. ``{"condition": ["A", "B"]}``. (default: ``None``).

    ref_levels : dict, optional
        An optional dictionary of the form ``{"factor": "ref_level"}``
        specifying for each factor the reference (control) level against which
        we're testing, e.g. ``{"condition": "A"}``. Factors that are left out
        will be assigned random reference levels. (default: ``None``).

    continuous_factors : list, optional
        An optional list of continuous (as opposed to categorical) factors, that should
        also be in ``design_factors``. Any factor in ``design_factors`` but not in
        ``continuous_factors`` will be considered categorical (default: ``None``).

    Returns
    -------
    pandas.DataFrame
        A DataFrame with experiment design information (to split cohorts).
        Indexed by sample barcodes.
    """
    if isinstance(
        design_factors, str
    ):  # if there is a single factor, convert to singleton list
        design_factors = [design_factors]

    # Check that factors in the design don't contain underscores. If so, convert
    # them to hyphens
    if np.any(["_" in factor for factor in design_factors]):
        warnings.warn(
            """Same factor names in the design contain underscores ('_'). They will
            be converted to hyphens ('-').""",
            UserWarning,
            stacklevel=2,
        )
        design_factors = [factor.replace("_", "-") for factor in design_factors]

    # Check that level factors in the design don't contain underscores. If so, convert
    # them to hyphens
    warning_issued = False
    for factor in design_factors:
        if ptypes.is_numeric_dtype(metadata[factor]):
            continue
        if np.any(["_" in value for value in metadata[factor]]):
            if not warning_issued:
                warnings.warn(
                    """Some factor levels in the design contain underscores ('_').
                    They will be converted to hyphens ('-').""",
                    UserWarning,
                    stacklevel=2,
                )
                warning_issued = True
            metadata[factor] = metadata[factor].apply(lambda x: x.replace("_", "-"))

    if continuous_factors is not None:
        for factor in continuous_factors:
            if factor not in design_factors:
                raise ValueError(
                    f"Continuous factor '{factor}' not in design factors: "
                    f"{design_factors}."
                )
        categorical_factors = [
            factor for factor in design_factors if factor not in continuous_factors
        ]
    else:
        categorical_factors = design_factors

    if levels is None:
        levels = {factor: np.unique(metadata[factor]) for factor in categorical_factors}

    # Check that there is at least one categorical factor
    if len(categorical_factors) > 0:
        design_matrix = pd.get_dummies(metadata[categorical_factors], drop_first=False)
        # Check if there are missing levels. If so, add them and set to 0.
        for factor in categorical_factors:
            for level in levels[factor]:
                if f"{factor}_{level}" not in design_matrix.columns:
                    design_matrix[f"{factor}_{level}"] = 0

        # Pick the first level as reference. Then, drop the column.
        for factor in categorical_factors:
            if ref_levels is not None and factor in ref_levels:
                ref = ref_levels[factor]
            else:
                ref = levels[factor][0]

            ref_level_name = f"{factor}_{ref}"
            design_matrix.drop(ref_level_name, axis="columns", inplace=True)

            # Add reference level as column name suffix
            design_matrix.columns = [
                f"{col}_vs_{ref}" if col.startswith(factor) else col
                for col in design_matrix.columns
            ]
    else:
        # There is no categorical factor in the design
        design_matrix = pd.DataFrame(index=metadata.index)

    # Add the intercept column
    design_matrix.insert(0, "intercept", 1)

    # Convert categorical factors one-hot encodings to int
    design_matrix = design_matrix.astype("int")

    # Add continuous factors
    if continuous_factors is not None:
        for factor in continuous_factors:
            # This factor should be numeric
            design_matrix[factor] = pd.to_numeric(metadata[factor])
    return design_matrix
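
A small, self-contained example with hypothetical metadata:

from fedpydeseq2.core.utils.design_matrix import build_design_matrix
import pandas as pd

metadata = pd.DataFrame(
    {"stage": ["I", "II", "II", "I"]},
    index=["sample1", "sample2", "sample3", "sample4"],
)

design = build_design_matrix(
    metadata, design_factors="stage", ref_levels={"stage": "I"}
)
# Columns: ["intercept", "stage_II_vs_I"]. The reference level column
# "stage_I" is dropped, and the reference is recorded in the "_vs_I" suffix.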

layers

build_layers

Module to construct the layers.

cooks

Module to set the cooks layer.

can_set_cooks_layer(adata, shared_state, raise_error=False)

Check if the Cook's distance can be set.

Parameters:

- adata (AnnData, required): The local adata.
- shared_state (dict or None, required): The shared state containing the Cook's dispersion values.
- raise_error (bool, default: False): Whether to raise an error if the Cook's distance cannot be set.

Returns:

- bool: Whether the Cook's distance can be set.

Raises:

- ValueError: If the Cook's distance cannot be set and raise_error is True.

Source code in fedpydeseq2/core/utils/layers/build_layers/cooks.py
def can_set_cooks_layer(
    adata: ad.AnnData, shared_state: dict | None, raise_error: bool = False
) -> bool:
    """Check if the Cook's distance can be set.

    Parameters
    ----------
    adata : ad.AnnData
        The local adata.

    shared_state : Optional[dict]
        The shared state containing the Cook's dispersion values.

    raise_error : bool
        Whether to raise an error if the Cook's distance cannot be set.

    Returns
    -------
    bool:
        Whether the Cook's distance can be set.

    Raises
    ------
    ValueError:
        If the Cook's distance cannot be set and raise_error is True.

    """
    if "cooks" in adata.layers.keys():
        return True
    if shared_state is None:
        if raise_error:
            raise ValueError(
                "To set cooks layer, there should be " "an input shared state"
            )
        else:
            return False
    has_non_zero = "non_zero" in adata.varm.keys()
    try:
        has_hat_diagonals = can_set_hat_diagonals_layer(
            adata, shared_state, raise_error
        )
    except ValueError as hat_diagonals_error:
        raise ValueError(
            "The Cook's distance cannot be set because the hat diagonals cannot be set."
        ) from hat_diagonals_error
    try:
        has_mu_LFC = can_set_mu_layer(
            local_adata=adata,
            lfc_param_name="LFC",
            mu_param_name="_mu_LFC",
        )
    except ValueError as mu_LFC_error:
        raise ValueError(
            "The Cook's distance cannot be set because the mu_LFC layer cannot be set."
        ) from mu_LFC_error
    has_X = adata.X is not None
    has_cooks_dispersions = "cooks_dispersions" in shared_state.keys()
    has_all = (
        has_non_zero
        and has_hat_diagonals
        and has_mu_LFC
        and has_X
        and has_cooks_dispersions
    )
    if not has_all and raise_error:
        raise ValueError(
            "The Cook's distance cannot be set because "
            "the following conditions are not met:"
            f"\n- has_non_zero: {has_non_zero}"
            f"\n- has_hat_diagonals: {has_hat_diagonals}"
            f"\n- has_mu_LFC: {has_mu_LFC}"
            f"\n- has_X: {has_X}"
            f"\n- has_cooks_dispersions: {has_cooks_dispersions}"
        )
    return has_all
set_cooks_layer(adata, shared_state)

Compute the Cook's distance from the shared state.

This function computes the Cook's distance from the shared state and stores it in the "cooks" layer of the local adata.

Parameters:

- adata (AnnData, required): The local adata.
- shared_state (dict, required): The shared state containing the Cook's dispersion values.

Source code in fedpydeseq2/core/utils/layers/build_layers/cooks.py
def set_cooks_layer(
    adata: ad.AnnData,
    shared_state: dict | None,
):
    """Compute the Cook's distance from the shared state.

    This function computes the Cook's distance from the shared state and stores it
    in the "cooks" layer of the local adata.

    Parameters
    ----------
    adata : ad.AnnData
        The local adata.

    shared_state : dict
        The shared state containing the Cook's dispersion values.

    """
    can_set_cooks_layer(adata, shared_state, raise_error=True)
    if "cooks" in adata.layers.keys():
        return
    # set all necessary layers
    assert isinstance(shared_state, dict)
    set_mu_layer(adata, lfc_param_name="LFC", mu_param_name="_mu_LFC")
    set_hat_diagonals_layer(adata, shared_state)
    num_vars = adata.uns["n_params"]
    cooks_dispersions = shared_state["cooks_dispersions"]
    V = (
        adata[:, adata.varm["non_zero"]].layers["_mu_LFC"]
        + cooks_dispersions[None, adata.varm["non_zero"]]
        * adata[:, adata.varm["non_zero"]].layers["_mu_LFC"] ** 2
    )
    squared_pearson_res = (
        adata[:, adata.varm["non_zero"]].X
        - adata[:, adata.varm["non_zero"]].layers["_mu_LFC"]
    ) ** 2 / V
    diag_mul = (
        adata[:, adata.varm["non_zero"]].layers["_hat_diagonals"]
        / (1 - adata[:, adata.varm["non_zero"]].layers["_hat_diagonals"]) ** 2
    )
    adata.layers["cooks"] = np.full((adata.n_obs, adata.n_vars), np.NaN)
    adata.layers["cooks"][:, adata.varm["non_zero"]] = (
        squared_pearson_res / num_vars * diag_mul
    )
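
For a single gene, the computation above reduces to the per-observation Cook's distance. A NumPy sketch with hypothetical inputs:

import numpy as np

counts = np.array([5.0, 12.0, 7.0])   # raw counts for one gene
mu = np.array([6.0, 10.0, 8.0])       # fitted means (the "_mu_LFC" layer)
disp = 0.1                            # Cook's dispersion for this gene
hat = np.array([0.3, 0.4, 0.3])       # hat-matrix diagonals
num_vars = 2                          # number of design parameters

V = mu + disp * mu**2                 # negative binomial variance
cooks = (counts - mu) ** 2 / V / num_vars * hat / (1 - hat) ** 2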

fit_lin_mu_hat

Module to reconstruct the fit_lin_mu_hat layer.

can_get_fit_lin_mu_hat(local_adata, raise_error=False)

Check if the fit_lin_mu_hat layer can be reconstructed.

Parameters:

- local_adata (AnnData, required): The local AnnData object.
- raise_error (bool, default: False): If True, raise an error if the fit_lin_mu_hat layer cannot be reconstructed.

Returns:

- bool: True if the fit_lin_mu_hat layer can be reconstructed, False otherwise.

Raises:

- ValueError: If the fit_lin_mu_hat layer cannot be reconstructed and raise_error is True.

Source code in fedpydeseq2/core/utils/layers/build_layers/fit_lin_mu_hat.py
def can_get_fit_lin_mu_hat(local_adata: ad.AnnData, raise_error: bool = False) -> bool:
    """Check if the fit_lin_mu_hat layer can be reconstructed.

    Parameters
    ----------
    local_adata : ad.AnnData
        The local AnnData object.

    raise_error : bool, optional
        If True, raise an error if the fit_lin_mu_hat layer cannot be reconstructed.

    Returns
    -------
    bool
        True if the fit_lin_mu_hat layer can be reconstructed, False otherwise.

    Raises
    ------
    ValueError
        If the fit_lin_mu_hat layer cannot be reconstructed and raise_error is True.

    """
    if "_fit_lin_mu_hat" in local_adata.layers.keys():
        return True
    try:
        y_hat_ok = can_get_y_hat(local_adata, raise_error=raise_error)
    except ValueError as y_hat_error:
        raise ValueError(
            f"Error while checking if y_hat can be reconstructed: {y_hat_error}"
        ) from y_hat_error

    has_size_factors = "size_factors" in local_adata.obsm.keys()
    has_non_zero = "non_zero" in local_adata.varm.keys()
    if not has_size_factors or not has_non_zero:
        if raise_error:
            raise ValueError(
                "Local adata must contain the size_factors obsm "
                "and the non_zero varm to compute the fit_lin_mu_hat layer."
                " Here are the keys present in the local adata: "
                f"obsm : {local_adata.obsm.keys()} and varm : {local_adata.varm.keys()}"
            )
        return False
    return y_hat_ok
set_fit_lin_mu_hat(local_adata, min_mu=0.5)

Calculate the _fit_lin_mu_hat layer using the provided local data.

Checks are performed to ensure necessary keys are present in the data.

Parameters:

- local_adata (AnnData, required): The local AnnData object containing the necessary keys for computation.
- min_mu (float, default: 0.5): The minimum value for mu.

Source code in fedpydeseq2/core/utils/layers/build_layers/fit_lin_mu_hat.py
def set_fit_lin_mu_hat(local_adata: ad.AnnData, min_mu: float = 0.5):
    """
    Calculate the _fit_lin_mu_hat layer using the provided local data.

    Checks are performed to ensure necessary keys are present in the data.

    Parameters
    ----------
    local_adata : ad.AnnData
        The local anndata object containing necessary keys for computation.
    min_mu : float, optional
        The minimum value for mu, defaults to 0.5.

    """
    can_get_fit_lin_mu_hat(local_adata, raise_error=True)
    if "_fit_lin_mu_hat" in local_adata.layers.keys():
        return
    set_y_hat(local_adata)
    mu_hat = local_adata.obsm["size_factors"][:, None] * local_adata.layers["_y_hat"]
    fit_lin_mu_hat = np.maximum(mu_hat, min_mu)

    fit_lin_mu_hat[:, ~local_adata.varm["non_zero"]] = np.nan
    local_adata.layers["_fit_lin_mu_hat"] = fit_lin_mu_hat

hat_diagonals

Module to set the hat diagonals layer.

can_set_hat_diagonals_layer(adata, shared_state, raise_error=False)

Check if the hat diagonals layer can be reconstructed.

Parameters:

- adata (AnnData, required): The AnnData object.
- shared_state (dict or None, required): The shared state dictionary.
- raise_error (bool, default: False): If True, raise an error if the hat diagonals layer cannot be reconstructed.

Returns:

- bool: True if the hat diagonals layer can be reconstructed, False otherwise.

Raises:

- ValueError: If the hat diagonals layer cannot be reconstructed and raise_error is True.

Source code in fedpydeseq2/core/utils/layers/build_layers/hat_diagonals.py
def can_set_hat_diagonals_layer(
    adata: ad.AnnData, shared_state: dict | None, raise_error: bool = False
) -> bool:
    """Check if the hat diagonals layer can be reconstructed.

    Parameters
    ----------
    adata : ad.AnnData
        The AnnData object.

    shared_state : Optional[dict]
        The shared state dictionary.

    raise_error : bool, optional
        If True, raise an error if the hat diagonals layer cannot be reconstructed.

    Returns
    -------
    bool
        True if the hat diagonals layer can be reconstructed, False otherwise.

    Raises
    ------
    ValueError
        If the hat diagonals layer cannot be reconstructed and raise_error is True.

    """
    if "_hat_diagonals" in adata.layers.keys():
        return True

    if shared_state is None:
        if raise_error:
            raise ValueError(
                "To set the _hat_diagonals layer, there" "should be a shared state."
            )
        else:
            return False

    has_design_matrix = "design_matrix" in adata.obsm.keys()
    has_lfc_param = "LFC" in adata.varm.keys()
    has_size_factors = "size_factors" in adata.obsm.keys()
    has_non_zero = "non_zero" in adata.varm.keys()
    has_dispersion = "dispersions" in adata.varm.keys()
    has_global_hat_matrix_inv = "global_hat_matrix_inv" in shared_state.keys()

    has_all = (
        has_design_matrix
        and has_lfc_param
        and has_size_factors
        and has_non_zero
        and has_global_hat_matrix_inv
        and has_dispersion
    )
    if not has_all:
        if raise_error:
            raise ValueError(
                "Adata must contain the design matrix obsm"
                ", the LFC varm, the dispersions varm, "
                "the size_factors obsm, the non_zero varm "
                "and the global_hat_matrix_inv "
                "in the shared state to compute the hat diagonals layer."
                " Here are the keys present in the adata: "
                f"obsm : {adata.obsm.keys()} and varm : {adata.varm.keys()}, and the "
                f"shared state keys: {shared_state.keys()}"
            )
        return False
    return True
make_hat_diag_batch(beta, global_hat_matrix_inv, design_matrix, size_factors, dispersions, min_mu=0.5)

Compute the H matrix for a batch of LFC estimates.

Parameters:

- beta (ndarray, required): Current LFC estimate, of shape (batch_size, n_params).
- global_hat_matrix_inv (ndarray, required): The inverse of the global hat matrix, of shape (batch_size, n_params, n_params).
- design_matrix (ndarray, required): The design matrix, of shape (n_obs, n_params).
- size_factors (ndarray, required): The size factors, of shape (n_obs,).
- dispersions (ndarray, required): The dispersions, of shape (batch_size,).
- min_mu (float, default: 0.5): Lower bound on estimated means, to ensure numerical stability.

Returns:

- ndarray: The H matrix, of shape (batch_size, n_obs).

Source code in fedpydeseq2/core/utils/layers/build_layers/hat_diagonals.py
def make_hat_diag_batch(
    beta: np.ndarray,
    global_hat_matrix_inv: np.ndarray,
    design_matrix: np.ndarray,
    size_factors: np.ndarray,
    dispersions: np.ndarray,
    min_mu: float = 0.5,
) -> np.ndarray:
    """
    Compute the H matrix for a batch of LFC estimates.

    Parameters
    ----------
    beta : np.ndarray
        Current LFC estimate, of shape (batch_size, n_params).
    global_hat_matrix_inv : np.ndarray
        The inverse of the global hat matrix, of shape (batch_size, n_params, n_params).
    design_matrix : np.ndarray
        The design matrix, of shape (n_obs, n_params).
    size_factors : np.ndarray
        The size factors, of shape (n_obs).
    dispersions : np.ndarray
        The dispersions, of shape (batch_size).
    min_mu : float
        Lower bound on estimated means, to ensure numerical stability.
        (default: ``0.5``).

    Returns
    -------
    np.ndarray
        The H matrix, of shape (batch_size, n_obs).

    """
    mu = size_factors[:, None] * np.exp(design_matrix @ beta.T)
    mu_clipped = np.maximum(
        mu,
        min_mu,
    )

    # W of shape (n_obs, batch_size)
    W = mu_clipped / (1.0 + mu_clipped * dispersions[None, :])

    # W_sq Of shape (batch_size, n_obs)
    W_sq = np.sqrt(W).T

    # Inside the diagonal operator is of shape (batch_size, n_obs, n_obs)
    # The diagonal operator takes the diagonal per gene in the batch
    # H is therefore of shape (batch_size, n_obs)
    H = np.diagonal(
        design_matrix @ global_hat_matrix_inv @ design_matrix.T,
        axis1=1,
        axis2=2,
    )

    H = W_sq * H * W_sq

    return H
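
A shape check with random inputs; the identity matrices below are a stand-in for the true global hat matrix inverse:

from fedpydeseq2.core.utils.layers.build_layers.hat_diagonals import (
    make_hat_diag_batch,
)
import numpy as np

rng = np.random.default_rng(0)
n_obs, n_params, batch_size = 5, 2, 3

design_matrix = rng.normal(size=(n_obs, n_params))
size_factors = np.ones(n_obs)
beta = rng.normal(size=(batch_size, n_params))
dispersions = np.full(batch_size, 0.1)
hat_matrix_inv = np.stack([np.eye(n_params)] * batch_size)

H = make_hat_diag_batch(
    beta, hat_matrix_inv, design_matrix, size_factors, dispersions
)
assert H.shape == (batch_size, n_obs)
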
set_hat_diagonals_layer(adata, shared_state, n_jobs=1, joblib_verbosity=0, joblib_backend='loky', batch_size=100, min_mu=0.5)

Compute the hat diagonals layer from the adata and the shared state.

Parameters:

- adata (AnnData, required): The AnnData object.
- shared_state (dict or None, required): The shared state dictionary. It must contain the global hat matrix inverse.
- n_jobs (int, default: 1): The number of jobs to use for parallel processing.
- joblib_verbosity (int, default: 0): The verbosity level of joblib.
- joblib_backend (str, default: 'loky'): The joblib backend to use.
- batch_size (int, default: 100): The batch size for parallel processing.
- min_mu (float, default: 0.5): Lower bound on estimated means, to ensure numerical stability.

Returns:

- ndarray: The hat diagonals layer, of shape (n_obs, n_vars).

Source code in fedpydeseq2/core/utils/layers/build_layers/hat_diagonals.py
def set_hat_diagonals_layer(
    adata: ad.AnnData,
    shared_state: dict | None,
    n_jobs: int = 1,
    joblib_verbosity: int = 0,
    joblib_backend: str = "loky",
    batch_size: int = 100,
    min_mu: float = 0.5,
):
    """
    Compute the hat diagonals layer from the adata and the shared state.

    Parameters
    ----------
    adata : ad.AnnData
        The AnnData object.

    shared_state : Optional[dict]
        The shared state dictionary.
        This dictionary must contain the global hat matrix inverse.

    n_jobs : int
        The number of jobs to use for parallel processing.

    joblib_verbosity : int
        The verbosity level of joblib.

    joblib_backend : str
        The joblib backend to use.

    batch_size : int
        The batch size for parallel processing.

    min_mu : float
        Lower bound on estimated means, to ensure numerical stability.

    Returns
    -------
    np.ndarray
        The hat diagonals layer, of shape (n_obs, n_vars).

    """
    can_set_hat_diagonals_layer(adata, shared_state, raise_error=True)
    if "_hat_diagonals" in adata.layers.keys():
        return

    assert shared_state is not None, (
        "To construct the _hat_diagonals layer, " "one must have a shared state."
    )

    gene_names = adata.var_names[adata.varm["non_zero"]]
    beta = adata.varm["LFC"].loc[gene_names].to_numpy()
    design_matrix = adata.obsm["design_matrix"].values
    size_factors = adata.obsm["size_factors"]

    dispersions = adata[:, gene_names].varm["dispersions"]

    # ---- Step 1: Compute the mu and the diagonal of the hat matrix ---- #

    with parallel_backend(joblib_backend):
        res = Parallel(n_jobs=n_jobs, verbose=joblib_verbosity)(
            delayed(make_hat_diag_batch)(
                beta[i : i + batch_size],
                shared_state["global_hat_matrix_inv"][i : i + batch_size],
                design_matrix,
                size_factors,
                dispersions[i : i + batch_size],
                min_mu,
            )
            for i in range(0, len(beta), batch_size)
        )

    H = np.concatenate(res)

    H_layer = np.full(adata.shape, np.NaN)

    H_layer[:, adata.var_names.get_indexer(gene_names)] = H.T

    adata.layers["_hat_diagonals"] = H_layer

mu_hat

Module to build the mu_hat layer.

can_get_mu_hat(local_adata, raise_error=False)

Check if the mu_hat layer can be reconstructed.

Parameters:

- local_adata (AnnData, required): The local AnnData object.
- raise_error (bool, default: False): If True, raise an error if the mu_hat layer cannot be reconstructed.

Returns:

- bool: True if the mu_hat layer can be reconstructed, False otherwise.

Raises:

- ValueError: If the mu_hat layer cannot be reconstructed and raise_error is True.

Source code in fedpydeseq2/core/utils/layers/build_layers/mu_hat.py
def can_get_mu_hat(local_adata: ad.AnnData, raise_error: bool = False) -> bool:
    """Check if the mu_hat layer can be reconstructed.

    Parameters
    ----------
    local_adata : ad.AnnData
        The local AnnData object.

    raise_error : bool, optional
        If True, raise an error if the mu_hat layer cannot be reconstructed.

    Returns
    -------
    bool
        True if the mu_hat layer can be reconstructed, False otherwise.

    Raises
    ------
    ValueError
        If the mu_hat layer cannot be reconstructed and raise_error is True.

    """
    if "_mu_hat" in local_adata.layers.keys():
        return True
    has_num_replicates = "num_replicates" in local_adata.uns
    has_n_params = "n_params" in local_adata.uns
    if not has_num_replicates or not has_n_params:
        if raise_error:
            raise ValueError(
                "Local adata must contain num_replicates in uns field "
                "and n_params in uns field to compute mu_hat."
                " Here are the keys present in the local adata: "
                f"uns : {local_adata.uns.keys()}"
            )
        return False
    # If the number of replicates is not equal to the number of parameters,
    # we need to reconstruct mu_hat from the adata.
    if len(local_adata.uns["num_replicates"]) != local_adata.uns["n_params"]:
        try:
            mu_hat_LFC_ok = can_set_mu_layer(
                local_adata=local_adata,
                lfc_param_name="_mu_hat_LFC",
                mu_param_name="_irls_mu_hat",
                raise_error=raise_error,
            )
        except ValueError as mu_hat_LFC_error:
            raise ValueError(
                "Error while checking if mu_hat_LFC can "
                f"be reconstructed: {mu_hat_LFC_error}"
            ) from mu_hat_LFC_error
        return mu_hat_LFC_ok
    else:
        try:
            fit_lin_mu_hat_ok = can_get_fit_lin_mu_hat(
                local_adata=local_adata,
                raise_error=raise_error,
            )
        except ValueError as fit_lin_mu_hat_error:
            raise ValueError(
                "Error while checking if fit_lin_mu_hat can be "
                f"reconstructed: {fit_lin_mu_hat_error}"
            ) from fit_lin_mu_hat_error
        return fit_lin_mu_hat_ok
set_mu_hat_layer(local_adata)

Reconstruct the mu_hat layer.

Parameters:

- local_adata (AnnData, required): The local AnnData object.

Source code in fedpydeseq2/core/utils/layers/build_layers/mu_hat.py
def set_mu_hat_layer(local_adata: ad.AnnData):
    """
    Reconstruct the mu_hat layer.

    Parameters
    ----------
    local_adata: ad.AnnData
        The local AnnData object.

    """
    can_get_mu_hat(local_adata, raise_error=True)
    if "_mu_hat" in local_adata.layers.keys():
        return

    if len(local_adata.uns["num_replicates"]) != local_adata.uns["n_params"]:
        set_mu_layer(
            local_adata=local_adata,
            lfc_param_name="_mu_hat_LFC",
            mu_param_name="_irls_mu_hat",
        )
        local_adata.layers["_mu_hat"] = local_adata.layers["_irls_mu_hat"].copy()
        return
    set_fit_lin_mu_hat(
        local_adata=local_adata,
    )
    local_adata.layers["_mu_hat"] = local_adata.layers["_fit_lin_mu_hat"].copy()

mu_layer

Module to construct mu layer from LFC estimates.

can_set_mu_layer(local_adata, lfc_param_name, mu_param_name, raise_error=False)

Check if the mu layer can be reconstructed.

Parameters:

- local_adata (AnnData, required): The local AnnData object.
- lfc_param_name (str, required): The name of the log fold changes parameter in the adata.
- mu_param_name (str, required): The name of the mu parameter in the adata.
- raise_error (bool, default: False): If True, raise an error if the mu layer cannot be reconstructed.

Returns:

- bool: True if the mu layer can be reconstructed, False otherwise.

Raises:

- ValueError: If the mu layer cannot be reconstructed and raise_error is True.

Source code in fedpydeseq2/core/utils/layers/build_layers/mu_layer.py
def can_set_mu_layer(
    local_adata: ad.AnnData,
    lfc_param_name: str,
    mu_param_name: str,
    raise_error: bool = False,
) -> bool:
    """Check if the mu layer can be reconstructed.

    Parameters
    ----------
    local_adata : ad.AnnData
        The local AnnData object.

    lfc_param_name : str
        The name of the log fold changes parameter in the adata.

    mu_param_name : str
        The name of the mu parameter in the adata.

    raise_error : bool, optional
        If True, raise an error if the mu layer cannot be reconstructed.

    Returns
    -------
    bool
        True if the mu layer can be reconstructed, False otherwise.

    Raises
    ------
    ValueError
        If the mu layer cannot be reconstructed and raise_error is True.

    """
    if mu_param_name in local_adata.layers.keys():
        return True

    has_design_matrix = "design_matrix" in local_adata.obsm.keys()
    has_lfc_param = lfc_param_name in local_adata.varm.keys()
    has_size_factors = "size_factors" in local_adata.obsm.keys()
    has_non_zero = "non_zero" in local_adata.varm.keys()

    has_all = has_design_matrix and has_lfc_param and has_size_factors and has_non_zero
    if not has_all:
        if raise_error:
            raise ValueError(
                "Local adata must contain the design matrix obsm"
                f", the {lfc_param_name} varm to compute the mu layer, "
                f"the size_factors obsm and the non_zero varm. "
                " Here are the keys present in the local adata: "
                f"obsm : {local_adata.obsm.keys()} and varm : {local_adata.varm.keys()}"
            )
        return False
    return True
make_mu_batch(beta, design_matrix, size_factors)

Compute the mu matrix for a batch of LFC estimates.

Parameters:

- beta (ndarray, required): Current LFC estimate, of shape (batch_size, n_params).
- design_matrix (ndarray, required): The design matrix, of shape (n_obs, n_params).
- size_factors (ndarray, required): The size factors, of shape (n_obs,).

Returns:

- mu (ndarray): The mu matrix, of shape (n_obs, batch_size).

Source code in fedpydeseq2/core/utils/layers/build_layers/mu_layer.py
def make_mu_batch(
    beta: np.ndarray,
    design_matrix: np.ndarray,
    size_factors: np.ndarray,
) -> np.ndarray:
    """
    Compute the mu matrix for a batch of LFC estimates.

    Parameters
    ----------
    beta : np.ndarray
        Current LFC estimate, of shape (batch_size, n_params).
    design_matrix : np.ndarray
        The design matrix, of shape (n_obs, n_params).
    size_factors : np.ndarray
        The size factors, of shape (n_obs).

    Returns
    -------
    mu : np.ndarray
        The mu matrix, of shape (n_obs, batch_size).

    """
    mu = size_factors[:, None] * np.exp(design_matrix @ beta.T)

    return mu
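
A worked example for a single gene, checking mu = size_factor * exp(design @ beta):

from fedpydeseq2.core.utils.layers.build_layers.mu_layer import make_mu_batch
import numpy as np

design_matrix = np.array([[1.0, 0.0],
                          [1.0, 1.0]])         # (n_obs=2, n_params=2)
size_factors = np.array([1.0, 2.0])
beta = np.array([[np.log(4.0), np.log(2.0)]])  # one gene

mu = make_mu_batch(beta, design_matrix, size_factors)
# mu[0, 0] = 1.0 * exp(log 4) = 4.0
# mu[1, 0] = 2.0 * exp(log 4 + log 2) = 16.0
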
set_mu_layer(local_adata, lfc_param_name, mu_param_name, n_jobs=1, joblib_verbosity=0, joblib_backend='loky', batch_size=100)

Reconstruct a mu layer from the adata and a given LFC field.

Parameters:

- local_adata (AnnData, required): The local AnnData object.
- lfc_param_name (str, required): The name of the log fold changes parameter in the adata.
- mu_param_name (str, required): The name of the mu parameter in the adata.
- n_jobs (int, default: 1): Number of jobs to run in parallel.
- joblib_verbosity (int, default: 0): Verbosity level of joblib.
- joblib_backend (str, default: 'loky'): Joblib backend to use.
- batch_size (int, default: 100): Batch size for parallelization.

Source code in fedpydeseq2/core/utils/layers/build_layers/mu_layer.py
def set_mu_layer(
    local_adata: ad.AnnData,
    lfc_param_name: str,
    mu_param_name: str,
    n_jobs: int = 1,
    joblib_verbosity: int = 0,
    joblib_backend: str = "loky",
    batch_size: int = 100,
):
    """Reconstruct a mu layer from the adata and a given LFC field.

    Parameters
    ----------
    local_adata : ad.AnnData
        The local AnnData object.

    lfc_param_name : str
        The name of the log fold changes parameter in the adata.

    mu_param_name : str
        The name of the mu parameter in the adata.

    n_jobs : int
        Number of jobs to run in parallel.

    joblib_verbosity : int
        Verbosity level of joblib.

    joblib_backend : str
        Joblib backend to use.

    batch_size : int
        Batch size for parallelization.

    """
    can_set_mu_layer(
        local_adata, lfc_param_name, mu_param_name=mu_param_name, raise_error=True
    )
    if mu_param_name in local_adata.layers.keys():
        return
    gene_names = local_adata.var_names[local_adata.varm["non_zero"]]
    beta = local_adata.varm[lfc_param_name].loc[gene_names].to_numpy()
    design_matrix = local_adata.obsm["design_matrix"].values
    size_factors = local_adata.obsm["size_factors"]

    with parallel_backend(joblib_backend):
        res = Parallel(n_jobs=n_jobs, verbose=joblib_verbosity)(
            delayed(make_mu_batch)(
                beta[i : i + batch_size],
                design_matrix,
                size_factors,
            )
            for i in range(0, len(beta), batch_size)
        )

    if len(res) == 0:
        mu = np.zeros((local_adata.shape[0], 0))
    else:
        mu = np.concatenate(list(res), axis=1)

    mu_layer = np.full(local_adata.shape, np.NaN)

    mu_layer[:, local_adata.var_names.get_indexer(gene_names)] = mu

    local_adata.layers[mu_param_name] = mu_layer
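
A usage sketch with the parameter names used elsewhere in this module:

from fedpydeseq2.core.utils.layers.build_layers.mu_layer import (
    can_set_mu_layer,
    set_mu_layer,
)

# Reconstruct the fitted means from the "LFC" estimates into the
# "_mu_LFC" layer, batching over genes with 4 joblib workers.
if can_set_mu_layer(local_adata, lfc_param_name="LFC", mu_param_name="_mu_LFC"):
    set_mu_layer(
        local_adata,
        lfc_param_name="LFC",
        mu_param_name="_mu_LFC",
        n_jobs=4,
    )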

normed_counts

Module to construct the normed_counts layer.

can_get_normed_counts(adata, raise_error=False)

Check if the normed_counts layer can be reconstructed.

Parameters:

- adata (AnnData, required): The local AnnData object.
- raise_error (bool, default: False): If True, raise an error if the normed_counts layer cannot be reconstructed.

Returns:

- bool: True if the normed_counts layer can be reconstructed, False otherwise.

Raises:

- ValueError: If the normed_counts layer cannot be reconstructed and raise_error is True.

Source code in fedpydeseq2/core/utils/layers/build_layers/normed_counts.py
def can_get_normed_counts(adata: ad.AnnData, raise_error: bool = False) -> bool:
    """Check if the normed_counts layer can be reconstructed.

    Parameters
    ----------
    adata : ad.AnnData
        The local AnnData object.

    raise_error : bool, optional
        If True, raise an error if the normed_counts layer cannot be reconstructed.

    Returns
    -------
    bool
        True if the normed_counts layer can be reconstructed, False otherwise.

    Raises
    ------
    ValueError
        If the normed_counts layer cannot be reconstructed and raise_error is True.

    """
    if "normed_counts" in adata.layers.keys():
        return True
    has_X = adata.X is not None
    has_size_factors = "size_factors" in adata.obsm.keys()
    if not has_X or not has_size_factors:
        if raise_error:
            raise ValueError(
                "Local adata must contain the X field "
                "and the size_factors obsm to compute the normed_counts layer."
                " Here are the keys present in the adata: "
                f" obsm : {adata.obsm.keys()}"
            )
        return False
    return True
set_normed_counts(adata)

Reconstruct the normed_counts layer.

Parameters:

- adata (AnnData, required): The local AnnData object.

Source code in fedpydeseq2/core/utils/layers/build_layers/normed_counts.py
def set_normed_counts(adata: ad.AnnData):
    """Reconstruct the normed_counts layer.

    Parameters
    ----------
    adata : ad.AnnData
        The local AnnData object.

    """
    can_get_normed_counts(adata, raise_error=True)
    if "normed_counts" in adata.layers.keys():
        return
    adata.layers["normed_counts"] = adata.X / adata.obsm["size_factors"][:, None]

sqerror

Module to construct the sqerror layer.

can_get_sqerror_layer(adata, raise_error=False)

Check if the squared error layer can be reconstructed.

Parameters:

- adata (AnnData, required): The local AnnData object.
- raise_error (bool, default: False): If True, raise an error if the squared error layer cannot be reconstructed.

Returns:

- bool: True if the squared error layer can be reconstructed, False otherwise.

Raises:

- ValueError: If the squared error layer cannot be reconstructed and raise_error is True.

Source code in fedpydeseq2/core/utils/layers/build_layers/sqerror.py
def can_get_sqerror_layer(adata: ad.AnnData, raise_error: bool = False) -> bool:
    """Check if the squared error layer can be reconstructed.

    Parameters
    ----------
    adata : ad.AnnData
        The local AnnData object.

    raise_error : bool, optional
        If True, raise an error if the squared error layer cannot be reconstructed.

    Returns
    -------
    bool
        True if the squared error layer can be reconstructed, False otherwise.

    Raises
    ------
    ValueError
        If the squared error layer cannot be reconstructed and raise_error is True.

    """
    if "sqerror" in adata.layers.keys():
        return True
    try:
        has_normed_counts = can_get_normed_counts(adata, raise_error=raise_error)
    except ValueError as normed_counts_error:
        raise ValueError(
            f"Error while checking if normed_counts can be"
            f" reconstructed: {normed_counts_error}"
        ) from normed_counts_error

    has_cell_means = "cell_means" in adata.varm.keys()
    has_cell_obs = "cells" in adata.obs.keys()
    if not has_normed_counts or not has_cell_means or not has_cell_obs:
        if raise_error:
            raise ValueError(
                "Local adata must contain the normed_counts layer, the cells obs, "
                "and the cell_means varm to compute the squared error layer."
                " Here are the keys present in the adata: "
                f"obs : {adata.obs.keys()}, varm : {adata.varm.keys()}"
            )
        return False
    return True
set_sqerror_layer(local_adata)

Compute the squared error between the normalized counts and the trimmed mean.

Parameters:

- local_adata (AnnData, required): Local AnnData. It is expected to have the following fields:
    - layers["normed_counts"]: the normalized counts.
    - varm["cell_means"]: the trimmed mean.
    - obs["cells"]: the cells.

Source code in fedpydeseq2/core/utils/layers/build_layers/sqerror.py
def set_sqerror_layer(local_adata: ad.AnnData):
    """Compute the squared error between the normalized counts and the trimmed mean.

    Parameters
    ----------
    local_adata : ad.AnnData
        Local AnnData. It is expected to have the following fields:
        - layers["normed_counts"]: the normalized counts.
        - varm["cell_means"]: the trimmed mean.
        - obs["cells"]: the cells.

    """
    can_get_sqerror_layer(local_adata, raise_error=True)
    if "sqerror" in local_adata.layers.keys():
        return
    cell_means = local_adata.varm["cell_means"]
    set_normed_counts(local_adata)
    if isinstance(cell_means, pd.DataFrame):
        cells = local_adata.obs["cells"]
        # restrict to the cells that are in the cell means columns
        cells = cells[cells.isin(cell_means.columns)]
        qmat = cell_means[cells].T
        qmat.index = cells.index

        # initialize with NaNs
        layer = np.full_like(local_adata.layers["normed_counts"], np.nan)
        indices = local_adata.obs_names.get_indexer(qmat.index)
        layer[indices, :] = (
            local_adata[qmat.index, :].layers["normed_counts"] - qmat
        ) ** 2
    else:
        layer = (local_adata.layers["normed_counts"] - cell_means[None, :]) ** 2
    local_adata.layers["sqerror"] = layer

y_hat

Module containing the necessary functions to reconstruct the y_hat layer.

can_get_y_hat(local_adata, raise_error=False)

Check if the y_hat layer can be reconstructed.

Parameters:

- local_adata (AnnData, required): The local AnnData object.
- raise_error (bool, default: False): If True, raise an error if the y_hat layer cannot be reconstructed.

Returns:

- bool: True if the y_hat layer can be reconstructed, False otherwise.

Raises:

- ValueError: If the y_hat layer cannot be reconstructed and raise_error is True.

Source code in fedpydeseq2/core/utils/layers/build_layers/y_hat.py
def can_get_y_hat(local_adata: ad.AnnData, raise_error: bool = False) -> bool:
    """Check if the y_hat layer can be reconstructed.

    Parameters
    ----------
    local_adata : ad.AnnData
        The local AnnData object.

    raise_error : bool, optional
        If True, raise an error if the y_hat layer cannot be reconstructed.

    Returns
    -------
    bool
        True if the y_hat layer can be reconstructed, False otherwise.

    Raises
    ------
    ValueError
        If the y_hat layer cannot be reconstructed and raise_error is True.

    """
    if "_y_hat" in local_adata.layers.keys():
        return True
    has_design_matrix = "design_matrix" in local_adata.obsm.keys()
    has_beta_rough_dispersions = "_beta_rough_dispersions" in local_adata.varm.keys()
    if not has_design_matrix or not has_beta_rough_dispersions:
        if raise_error:
            raise ValueError(
                "Local adata must contain the design matrix obsm "
                "and the _beta_rough_dispersions varm to compute the y_hat layer."
                " Here are the keys present in the local adata: "
                f"obsm : {local_adata.obsm.keys()} and varm : {local_adata.varm.keys()}"
            )
        return False
    return True
set_y_hat(local_adata)

Reconstruct the y_hat layer.

Parameters:

- local_adata (AnnData, required): The local AnnData object.

Source code in fedpydeseq2/core/utils/layers/build_layers/y_hat.py
def set_y_hat(local_adata: ad.AnnData):
    """Reconstruct the y_hat layer.

    Parameters
    ----------
    local_adata : ad.AnnData
        The local AnnData object.

    """
    can_get_y_hat(local_adata, raise_error=True)
    if "_y_hat" in local_adata.layers.keys():
        return
    y_hat = (
        local_adata.obsm["design_matrix"].to_numpy()
        @ local_adata.varm["_beta_rough_dispersions"].T
    )
    local_adata.layers["_y_hat"] = y_hat

build_refit_adata

set_basic_refit_adata(self)

Set the basic refit adata from the local adata.

This function checks that the local adata is loaded and the replaced genes are computed and stored in the varm field. It then sets the refit adata from the local adata.

Parameters:

- self (Any, required): The object containing the local adata and the refit adata.

Source code in fedpydeseq2/core/utils/layers/build_refit_adata.py
def set_basic_refit_adata(self: Any):
    """Set the basic refit adata from the local adata.

    This function checks that the local adata is loaded and the replaced
    genes are computed and stored in the varm field. It then sets the refit
    adata from the local adata.

    Parameters
    ----------
    self : Any
        The object containing the local adata and the refit adata.

    """
    assert (
        self.local_adata is not None
    ), "Local adata must be loaded before setting the refit adata."
    assert (
        "replaced" in self.local_adata.varm.keys()
    ), "Replaced genes must be computed before setting the refit adata."

    genes_to_replace = pd.Series(
        self.local_adata.varm["replaced"], index=self.local_adata.var_names
    )
    if self.refit_adata is None:
        self.refit_adata = self.local_adata[:, genes_to_replace].copy()
        # Clear the varm field of the refit adata
        self.refit_adata.varm = None
    elif "refitted" not in self.local_adata.varm.keys():
        self.refit_adata.X = self.local_adata[:, genes_to_replace].X.copy()
        self.refit_adata.obsm = self.local_adata.obsm
    else:
        genes_to_refit = pd.Series(
            self.local_adata.varm["refitted"], index=self.local_adata.var_names
        )
        self.refit_adata.X = self.local_adata[:, genes_to_refit].X.copy()
        self.refit_adata.obsm = self.local_adata.obsm

set_imputed_counts_refit_adata(self)

Set the imputed counts in the refit adata.

This function checks that the refit adata, the local adata, the replaced genes, the trimmed mean normed counts, the size factors, the cooks G cutoff, and the replaceable genes are computed and stored in the appropriate fields. It then sets the imputed counts in the refit adata.

Note that this function must be run on an object which already contains a refit_adata, whose counts, obsm and uns have been set with the set_basic_refit_adata function.

Parameters:

- self (Any, required): The object containing the refit adata, the local adata, the replaced genes, the trimmed mean normed counts, the size factors, the cooks G cutoff, and the replaceable genes.

Source code in fedpydeseq2/core/utils/layers/build_refit_adata.py
def set_imputed_counts_refit_adata(self: Any):
    """Set the imputed counts in the refit adata.

    This function checks that the refit adata, the local adata, the replaced
    genes, the trimmed mean normed counts, the size factors, the cooks G cutoff,
    and the replaceable genes are computed and stored in the appropriate fields.
    It then sets the imputed counts in the refit adata.

    Note that this function must be run on an object which already contains
    a refit_adata, whose counts, obsm and uns have been set with the
    `set_basic_refit_adata` function.

    Parameters
    ----------
    self : Any
        The object containing the refit adata, the local adata, the replaced
        genes, the trimmed mean normed counts, the size factors, the cooks G
        cutoff, and the replaceable genes.

    """
    assert (
        self.refit_adata is not None
    ), "Refit adata must be loaded before setting the imputed counts."
    assert (
        self.local_adata is not None
    ), "Local adata must be loaded before setting the imputed counts."
    assert (
        "replaced" in self.local_adata.varm.keys()
    ), "Replaced genes must be computed before setting the imputed counts."
    assert (
        "_trimmed_mean_normed_counts" in self.refit_adata.varm.keys()
    ), "Trimmed mean normed counts must be computed before setting the imputed counts."
    assert (
        "size_factors" in self.refit_adata.obsm.keys()
    ), "Size factors must be computed before setting the imputed counts."
    assert (
        "_where_cooks_g_cutoff" in self.local_adata.uns.keys()
    ), "Cooks G cutoff must be computed before setting the imputed counts."
    assert (
        "replaceable" in self.refit_adata.obsm.keys()
    ), "Replaceable genes must be computed before setting the imputed counts."

    trimmed_mean_normed_counts = self.refit_adata.varm["_trimmed_mean_normed_counts"]

    replacement_counts = pd.DataFrame(
        self.refit_adata.obsm["size_factors"][:, None] * trimmed_mean_normed_counts,
        columns=self.refit_adata.var_names,
        index=self.refit_adata.obs_names,
    ).astype(int)

    idx = np.zeros(self.local_adata.shape, dtype=bool)
    idx[self.local_adata.uns["_where_cooks_g_cutoff"]] = True

    # Restrict to the genes to replace
    if "refitted" not in self.local_adata.varm.keys():
        idx = idx[:, self.local_adata.varm["replaced"]]
    else:
        idx = idx[:, self.local_adata.varm["refitted"]]

    # Replace the counts
    self.refit_adata.X[
        self.refit_adata.obsm["replaceable"][:, None] & idx
    ] = replacement_counts.values[self.refit_adata.obsm["replaceable"][:, None] & idx]
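
The imputation step scales the trimmed mean normed counts back to the raw count scale; a sketch with hypothetical values:

import numpy as np

size_factors = np.array([0.5, 2.0])
trimmed_mean_normed_counts = np.array([[4.0, 10.0],
                                       [4.0, 10.0]])

replacement_counts = (
    size_factors[:, None] * trimmed_mean_normed_counts
).astype(int)
# array([[ 2,  5], [ 8, 20]])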

cooks_layer

can_skip_local_cooks_preparation(self)

Check if the Cook's distance is in the layers to save.

This function checks if the Cook's distance is in the layers to save.

Parameters:

- self (Any, required): The object.

Returns:

- bool: Whether the Cook's distance is in the layers to save.

Source code in fedpydeseq2/core/utils/layers/cooks_layer.py
def can_skip_local_cooks_preparation(self: Any) -> bool:
    """Check if the Cook's distance is in the layers to save.

    This function checks if the Cook's distance is in the layers to save.

    Parameters
    ----------
    self : Any
        The object.

    Returns
    -------
    bool:
        Whether the Cook's distance is in the layers to save.

    """
    only_from_disk = (
        not hasattr(self, "save_layers_to_disk") or self.save_layers_to_disk
    )
    if only_from_disk and "cooks" in self.local_adata.layers.keys():
        return True
    if hasattr(self, "layers_to_save_on_disk"):
        layers_to_save_on_disk = self.layers_to_save_on_disk
        if (
            layers_to_save_on_disk is not None
            and "local_adata" in layers_to_save_on_disk
            and layers_to_save_on_disk["local_adata"] is not None
            and "cooks" in layers_to_save_on_disk["local_adata"]
        ):
            return True
    return False

make_hat_matrix_summands_batch(design_matrix, size_factors, beta, dispersions, min_mu)

Make the local hat matrix.

This is quite similar to the make_irls_summands_batch function, but it does not require the counts, and returns only the H matrix.

This is used in the final step of the IRLS algorithm to compute the local hat matrix.

Parameters:

- design_matrix (ndarray, required): The design matrix, of shape (n_obs, n_params).
- size_factors (ndarray, required): The size factors, of shape (n_obs,).
- beta (ndarray, required): The log fold change matrix, of shape (batch_size, n_params).
- dispersions (ndarray, required): The dispersions, of shape (batch_size,).
- min_mu (float, required): Lower bound on estimated means, to ensure numerical stability.

Returns:

- H (ndarray): The H matrix, of shape (batch_size, n_params, n_params).

Source code in fedpydeseq2/core/utils/layers/cooks_layer.py
def make_hat_matrix_summands_batch(
    design_matrix: np.ndarray,
    size_factors: np.ndarray,
    beta: np.ndarray,
    dispersions: np.ndarray,
    min_mu: float,
) -> np.ndarray:
    """Make the local hat matrix.

    This is quite similar to the make_irls_summands_batch function, but it does not
    require the counts, and returns only the H matrix.

    This is used in the final step of the IRLS algorithm to compute the local hat
    matrix.

    Parameters
    ----------
    design_matrix : np.ndarray
        The design matrix, of shape (n_obs, n_params).
    size_factors : np.ndarray
        The size factors, of shape (n_obs).
    beta : np.ndarray
        The log fold change matrix, of shape (batch_size, n_params).
    dispersions : np.ndarray
        The dispersions, of shape (batch_size).
    min_mu : float
        Lower bound on estimated means, to ensure numerical stability.


    Returns
    -------
    H : np.ndarray
        The H matrix, of shape (batch_size, n_params, n_params).
    """
    mu = size_factors[:, None] * np.exp(design_matrix @ beta.T)

    mu = np.maximum(mu, min_mu)

    W = mu / (1.0 + mu * dispersions[None, :])

    H = (design_matrix.T[:, :, None] * W).transpose(2, 0, 1) @ design_matrix[None, :, :]

    return H
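
A small self-contained sketch of the expected shapes, using toy values (not from the package). Each center computes such summands on its own samples; prepare_cooks_agg then sums them across centers before inverting:

import numpy as np

from fedpydeseq2.core.utils.layers.cooks_layer import make_hat_matrix_summands_batch

rng = np.random.default_rng(0)
n_obs, n_params, batch_size = 5, 2, 3

design_matrix = np.hstack([np.ones((n_obs, 1)), rng.normal(size=(n_obs, 1))])
size_factors = np.ones(n_obs)
beta = rng.normal(size=(batch_size, n_params))  # one LFC vector per gene
dispersions = np.full(batch_size, 0.1)

H = make_hat_matrix_summands_batch(
    design_matrix, size_factors, beta, dispersions, min_mu=0.5
)
assert H.shape == (batch_size, n_params, n_params)  # one p x p summand per gene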

prepare_cooks_agg(method)

Decorate the aggregation step to compute the Cook's distance.

This decorator is meant to be placed on the aggregation step just before a local step that needs the "cooks" layer. The decorator checks whether the shared state contains the keys necessary for the Cook's distance computation. If it does not, the Cook's distance must have been saved via layers_to_save_on_disk. Otherwise, it computes the Cook's dispersion and the global hat matrix inverse, and then calls the method.

It will add the following keys to the shared state:

  • cooks_dispersions
  • global_hat_matrix_inv

Parameters:

Name Type Description Default
method Callable

The aggregation method to decorate. It must have the following signature: method(self, shared_states: Optional[list], **method_parameters).

required

Returns:

Name Type Description
Callable

The decorated method.

Source code in fedpydeseq2/core/utils/layers/cooks_layer.py
def prepare_cooks_agg(method: Callable):
    """Decorate the aggregation step to compute the Cook's distance.

    This decorator is supposed to be placed on the aggregation step just before
    a local step which needs the "cooks" layer. The decorator will check if the
    shared state contains the necessary keys for the Cook's distance computation.
    If this is not the case, then the Cook's distance must have been saved in the
    layers_to_save_on_disk.
    It will compute the Cook's dispersion, the hat matrix inverse, and then call
    the method.

    It will add the following keys to the shared state:
    - cooks_dispersions
    - global_hat_matrix_inv

    Parameters
    ----------
    method : Callable
        The aggregation method to decorate.
        It must have the following signature:
        method(self, shared_states: Optional[list], **method_parameters).

    Returns
    -------
    Callable:
        The decorated method.

    """

    @wraps(method)
    def method_inner(
        self,
        shared_states: list | None,
        **method_parameters,
    ):
        # Check that the shared state contains the necessary keys
        # for the Cook's distance computation
        # If this is not the case, then the cooks distance must have
        # been saved in the layers_to_save

        try:
            assert isinstance(shared_states, list)
            assert "n_samples" in shared_states[0].keys()
            assert "varEst" in shared_states[0].keys()
            assert "mean_normed_counts" in shared_states[0].keys()
            assert "local_hat_matrix" in shared_states[0].keys()
        except AssertionError as assertion_error:
            only_from_disk = (
                not hasattr(self, "save_layers_to_disk") or self.save_layers_to_disk
            )
            if only_from_disk:
                return method(self, shared_states, **method_parameters)
            elif isinstance(shared_states, list) and shared_states[0].get(
                "_skip_cooks", False
            ):
                return method(self, shared_states, **method_parameters)
            raise ValueError(
                "The shared state does not contain the necessary keys for"
                "the Cook's distance computation."
            ) from assertion_error

        assert isinstance(shared_states, list)

        # ---- Step 1: Compute Cooks dispersion ---- #

        n_sample_tot = sum(
            [shared_state["n_samples"] for shared_state in shared_states]
        )
        varEst = shared_states[0]["varEst"]
        mean_normed_counts = (
            np.array(
                [
                    (shared_state["mean_normed_counts"] * shared_state["n_samples"])
                    for shared_state in shared_states
                ]
            ).sum(axis=0)
            / n_sample_tot
        )
        mask_zero = mean_normed_counts == 0
        mask_varEst_zero = varEst == 0
        alpha = varEst - mean_normed_counts
        alpha[~mask_zero] = alpha[~mask_zero] / mean_normed_counts[~mask_zero] ** 2
        alpha[mask_varEst_zero & mask_zero] = np.nan
        alpha[mask_varEst_zero & (~mask_zero)] = (
            np.inf * alpha[mask_varEst_zero & (~mask_zero)]
        )

        # cannot use the typical min_disp = 1e-8 here or else all counts in the same
        # group as the outlier count will get an extreme Cook's distance
        minDisp = 0.04
        alpha = cast(pd.Series, np.maximum(alpha, minDisp))

        # --- Step 2: Compute the hat matrix inverse --- #

        global_hat_matrix = sum([state["local_hat_matrix"] for state in shared_states])
        n_jobs, joblib_verbosity, joblib_backend, batch_size = get_joblib_parameters(
            self
        )
        ridge_factor = np.diag(np.repeat(1e-6, global_hat_matrix.shape[1]))
        with parallel_backend(joblib_backend):
            res = Parallel(n_jobs=n_jobs, verbose=joblib_verbosity)(
                delayed(np.linalg.inv)(hat_matrices + ridge_factor)
                for hat_matrices in np.split(
                    global_hat_matrix,
                    range(
                        batch_size,
                        len(global_hat_matrix),
                        batch_size,
                    ),
                )
            )

        global_hat_matrix_inv = np.concatenate(res)

        # ---- Step 3: Run the method ---- #

        shared_state = method(self, shared_states, **method_parameters)

        # ---- Step 4: Save the Cook's dispersion and the hat matrix inverse ---- #

        shared_state["cooks_dispersions"] = alpha
        shared_state["global_hat_matrix_inv"] = global_hat_matrix_inv

        return shared_state

    return method_inner
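
A hedged sketch of where the decorator sits; the class and method names below are hypothetical, and the decorated method only needs the documented signature:

from fedpydeseq2.core.utils.layers.cooks_layer import prepare_cooks_agg


class MyAggregation:
    """Hypothetical aggregation step."""

    @prepare_cooks_agg
    def merge_states(self, shared_states, **method_parameters):
        # Regular aggregation logic goes here. When the inputs contain the
        # required keys, the decorator computes the Cook's dispersions and the
        # global hat matrix inverse, and attaches them to the returned dict.
        return {"n_centers": len(shared_states)}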

prepare_cooks_local(method)

Decorate the local step that immediately precedes a local step needing the Cook's layer.

This preparation is only applied if the Cook's layer is not already present or must not be saved between steps.

This step is used to compute the local hat matrix and the mean normed counts.

Before the method is called, varEst is read from the local adata, or from the shared state if it is not yet stored in the local adata.

The local hat matrix and the mean normed counts are computed, and the following keys are added to the shared state:

  • local_hat_matrix
  • mean_normed_counts
  • n_samples
  • varEst

Parameters:

Name Type Description Default
method Callable

The remote_data method to decorate.

required

Returns:

Name Type Description
Callable

The decorated method.

Source code in fedpydeseq2/core/utils/layers/cooks_layer.py
def prepare_cooks_local(method: Callable):
    """Decorate the local method just preceding a local method needing cooks.

    This preparation is only applied if the Cook's layer is not already present
    or must not be saved between steps.

    This step is used to compute the local hat matrix and the mean normed counts.

    Before the method is called, varEst is read from the local adata, or from
    the shared state if it is not yet stored in the local adata.

    The local hat matrix and the mean normed counts are computed, and the following
    keys are added to the shared state:
    - local_hat_matrix
    - mean_normed_counts
    - n_samples
    - varEst

    Parameters
    ----------
    method : Callable
        The remote_data method to decorate.

    Returns
    -------
    Callable:
        The decorated method.

    """

    @wraps(method)
    def method_inner(
        self,
        data_from_opener: ad.AnnData,
        shared_state: Any = None,
        **method_parameters,
    ):
        # ---- Step 0: If can skip, we skip ---- #
        if can_skip_local_cooks_preparation(self):
            shared_state = method(
                self, data_from_opener, shared_state, **method_parameters
            )
            shared_state["_skip_cooks"] = True
            return shared_state

        # ---- Step 1: Access varEst ---- #

        if "varEst" in self.local_adata.varm.keys():
            varEst = self.local_adata.varm["varEst"]
        else:
            assert "varEst" in shared_state
            varEst = shared_state["varEst"]
            self.local_adata.varm["varEst"] = varEst

        # ---- Step 2: Run the method ---- #
        shared_state = method(self, data_from_opener, shared_state, **method_parameters)

        # ---- Step 3: Compute the local hat matrix ---- #

        n_jobs, joblib_verbosity, joblib_backend, batch_size = get_joblib_parameters(
            self
        )
        # Compute hat matrix
        (
            gene_names,
            design_matrix,
            size_factors,
            counts,
            dispersions,
            beta,
        ) = get_lfc_utils_from_gene_mask_adata(
            self.local_adata,
            None,
            "dispersions",
            lfc_param_name="LFC",
        )

        with parallel_backend(joblib_backend):
            res = Parallel(n_jobs=n_jobs, verbose=joblib_verbosity)(
                delayed(make_hat_matrix_summands_batch)(
                    design_matrix,
                    size_factors,
                    beta[i : i + batch_size],
                    dispersions[i : i + batch_size],
                    self.min_mu,
                )
                for i in range(0, len(beta), batch_size)
            )

        if len(res) == 0:
            H = np.zeros((0, beta.shape[1], beta.shape[1]))
        else:
            H = np.concatenate(res)

        shared_state["local_hat_matrix"] = H

        # ---- Step 4: Compute the mean normed counts ---- #

        mean_normed_counts = self.local_adata.layers["normed_counts"].mean(axis=0)

        shared_state["mean_normed_counts"] = mean_normed_counts
        shared_state["n_samples"] = self.local_adata.n_obs
        shared_state["varEst"] = varEst
        shared_state["_skip_cooks"] = False

        return shared_state

    return method_inner
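
Symmetrically, a hypothetical local step decorated with prepare_cooks_local (the names are illustrative only):

from fedpydeseq2.core.utils.layers.cooks_layer import prepare_cooks_local


class MyLocalStep:
    """Hypothetical local step."""

    min_mu = 0.5  # read by the decorator when building the hat matrix summands

    @prepare_cooks_local
    def local_step(self, data_from_opener, shared_state, **method_parameters):
        # Regular local logic goes here. Unless the Cook's layer can be
        # skipped, the decorator adds "local_hat_matrix",
        # "mean_normed_counts", "n_samples" and "varEst" to the output.
        return {}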

joblib_utils

get_joblib_parameters(x)

Get the joblib parameters from an object, and return them as a tuple.

If the object has no joblib parameters, default values are returned.

Parameters:

Name Type Description Default
x Any

Object from which to extract the joblib parameters.

required

Returns:

Name Type Description
n_jobs int

Number of jobs to run in parallel.

joblib_verbosity int

Verbosity level of joblib.

joblib_backend str

Joblib backend.

batch_size int

Batch size for the IRLS algorithm.

Source code in fedpydeseq2/core/utils/layers/joblib_utils.py
def get_joblib_parameters(x: Any) -> tuple[int, int, str, int]:
    """
    Get the joblib parameters from an object, and return them as a tuple.

    If the object has no joblib parameters, default values are returned.

    Parameters
    ----------
    x: Any
        Object from which to extract the joblib parameters.

    Returns
    -------
    n_jobs: int
        Number of jobs to run in parallel.
    joblib_verbosity: int
        Verbosity level of joblib.
    joblib_backend: str
        Joblib backend.
    batch_size: int
        Batch size for the IRLS algorithm.

    """
    n_jobs = x.num_jobs if hasattr(x, "num_jobs") else 1

    joblib_verbosity = x.joblib_verbosity if hasattr(x, "joblib_verbosity") else 0
    joblib_backend = x.joblib_backend if hasattr(x, "joblib_backend") else "loky"
    batch_size = x.irls_batch_size if hasattr(x, "irls_batch_size") else 100
    return n_jobs, joblib_verbosity, joblib_backend, batch_size
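
For example (the SimpleNamespace is a stand-in for any strategy object):

import types

from fedpydeseq2.core.utils.layers.joblib_utils import get_joblib_parameters

strategy = types.SimpleNamespace(num_jobs=4, irls_batch_size=50)
print(get_joblib_parameters(strategy))  # (4, 0, 'loky', 50)
print(get_joblib_parameters(object()))  # defaults: (1, 0, 'loky', 100)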

reconstruct_adatas_decorator

Module containing a decorator to handle simple layers.

This wrapper is used to load and save simple layers from the adata object. These simple layers are defined in SIMPLE_LAYERS.

check_and_load_layers(self, adata_name, layers_to_load, shared_state, only_from_disk)

Check and load layers for a given adata_name.

This function checks the availability of the layers to load and loads them into the adata referenced by adata_name.

Parameters:

Name Type Description Default
self Any

The object containing the adata.

required
adata_name str

The name of the adata to load the layers into.

required
layers_to_load dict[str, Optional[list[str]]]

The layers to load for each adata. It must have adata_name as a key.

required
shared_state Optional[dict]

The shared state.

required
only_from_disk bool

Whether to consider only the layers already stored in the adata (loaded from disk) as available.

required
Source code in fedpydeseq2/core/utils/layers/reconstruct_adatas_decorator.py
def check_and_load_layers(
    self: Any,
    adata_name: str,
    layers_to_load: dict[str, list[str] | None],
    shared_state: dict | None,
    only_from_disk: bool,
):
    """Check and load layers for a given adata_name.

    This function checks the availability of the layers to load
    and loads them into the adata referenced by adata_name.

    Parameters
    ----------
    self : Any
        The object containing the adata.
    adata_name : str
        The name of the adata to load the layers into.
    layers_to_load : dict[str, Optional[list[str]]]
        The layers to load for each adata. It must have adata_name
        as a key.
    shared_state : Optional[dict]
        The shared state.
    only_from_disk : bool
        Whether to consider only the layers already stored in the adata
        (loaded from disk) as available.

    """
    adata = getattr(self, adata_name)
    layers_to_load_adata = layers_to_load[adata_name]
    available_layers_adata = get_available_layers(
        adata,
        shared_state,
        refit=adata_name == "refit_adata",
        all_layers_from_disk=only_from_disk,
    )
    if layers_to_load_adata is None:
        layers_to_load_adata = available_layers_adata
    else:
        assert np.all(
            [layer in available_layers_adata for layer in layers_to_load_adata]
        )
    if adata is None:
        return
    assert layers_to_load_adata is not None
    n_jobs, joblib_verbosity, joblib_backend, batch_size = get_joblib_parameters(self)
    load_layers(
        adata=adata,
        shared_state=shared_state,
        layers_to_load=layers_to_load_adata,
        n_jobs=n_jobs,
        joblib_verbosity=joblib_verbosity,
        joblib_backend=joblib_backend,
        batch_size=batch_size,
    )

reconstruct_adatas(method)

Decorate a method to load layers and remove them before saving the state.

This decorator loads the layers from the data_from_opener and the adata object before calling the method. It then removes the layers from the adata object after the method is called.

The object self CAN have the following attributes:

  • save_layers_to_disk: if this attribute is absent or set to True, we save all the layers on disk, without removing them at the end of each local step. If it is False, we remove all layers that must be removed at the end of each local step. This attribute takes precedence over all the others described below.

  • layers_to_save_on_disk: if this attribute exists, it contains the layers that must be saved on disk at EVERY local step. It can be either None (in which case the default behaviour is to save no layers) or a dictionary with a refit_adata and a local_adata key. The associated values contain either None (no layers) or a list of layers to save at each step.

This decorator adds two parameters to each method decorated with it:

  • layers_to_load
  • layers_to_save_on_disk

If the layers_to_load is None, the default is to load all available layers. Else, we only load the layers specified in the layers_to_load argument.

The layers_to_save_on_disk argument is ADDED to the layers_to_save_on_disk attribute of self for the duration of the method and then removed. That way, the inner method can access the names of the layers_to_save_on_disk which will effectively be saved at the end of the step.

Parameters:

Name Type Description Default
method Callable

The method to decorate. This method is expected to have the following signature: method(self, data_from_opener: ad.AnnData, shared_state: Any, **method_parameters).

required

Returns:

Type Description
Callable

The decorated method, which loads the simple layers before calling the method and removes the simple layers after the method is called.

Source code in fedpydeseq2/core/utils/layers/reconstruct_adatas_decorator.py
def reconstruct_adatas(method: Callable):
    """Decorate a method to load layers and remove them before saving the state.

    This decorator loads the layers from the data_from_opener and the adata
    object before calling the method. It then removes the layers from the adata
    object after the method is called.

    The object self CAN have the following attributes:

    - save_layers_to_disk: if this attribute is absent or set to True, we save
    all the layers on disk, without removing them at the end of each local step.
    If it is False, we remove all layers that must be removed at the end of each
    local step. This attribute takes precedence over all the others described
    below.

    - layers_to_save_on_disk: if this attribute exists, it contains the layers
    that must be saved on disk at EVERY local step. It can be either None (in
    which case the default behaviour is to save no layers) or a dictionary with
    a refit_adata and a local_adata key. The associated values contain either
    None (no layers) or a list of layers to save at each step.

    This decorator adds two parameters to each method decorated with it:
    - layers_to_load
    - layers_to_save_on_disk

    If the layers_to_load is None, the default is to load all available layers.
    Else, we only load the layers specified in the layers_to_load argument.

    The layers_to_save_on_disk argument is ADDED to the layers_to_save_on_disk attribute
    of self for the duration of the method and then removed. That way, the inner
    method can access the names of the layers_to_save_on_disk which will effectively
    be saved at the end of the step.

    Parameters
    ----------
    method : Callable
        The method to decorate. This method is expected to have the following signature:
        method(self, data_from_opener: ad.AnnData, shared_state: Any,
         **method_parameters).

    Returns
    -------
    Callable
        The decorated method, which loads the simple layers before calling the method
        and removes the simple layers after the method is called.

    """

    @wraps(method)
    def method_inner(
        self,
        data_from_opener: ad.AnnData,
        shared_state: Any = None,
        layers_to_load: LayersToLoadSaveType = None,
        layers_to_save_on_disk: LayersToLoadSaveType = None,
        **method_parameters,
    ):
        if layers_to_load is None:
            layers_to_load = {"local_adata": None, "refit_adata": None}
        if hasattr(self, "layers_to_save_on_disk"):
            if self.layers_to_save_on_disk is None:
                global_layers_to_save_on_disk = None
            else:
                global_layers_to_save_on_disk = self.layers_to_save_on_disk.copy()

            if global_layers_to_save_on_disk is None:
                self.layers_to_save_on_disk = {"local_adata": [], "refit_adata": []}
        else:
            self.layers_to_save_on_disk = {"local_adata": [], "refit_adata": []}

        if layers_to_save_on_disk is None:
            layers_to_save_on_disk = {"local_adata": [], "refit_adata": []}

        # Set the layers_to_save_on_disk attribute to the union of the layers specified
        # in the argument and those in the attribute, to be accessed by the method.
        assert isinstance(self.layers_to_save_on_disk, dict)
        for adata_name in ["local_adata", "refit_adata"]:
            if self.layers_to_save_on_disk[adata_name] is None:
                self.layers_to_save_on_disk[adata_name] = []
            if layers_to_save_on_disk[adata_name] is None:
                layers_to_save_on_disk[adata_name] = []
            self.layers_to_save_on_disk[adata_name] = list(
                set(
                    layers_to_save_on_disk[adata_name]
                    + self.layers_to_save_on_disk[adata_name]
                )
            )

        # Check that the layers_to_load and layers_to_save are valid
        assert set(layers_to_load.keys()) == {"local_adata", "refit_adata"}
        assert set(self.layers_to_save_on_disk.keys()) == {"local_adata", "refit_adata"}

        # Load the counts of the adata
        if self.local_adata is not None:
            if self.local_adata.X is None:
                self.local_adata.X = data_from_opener.X

        # Load the available layers
        only_from_disk = (
            not hasattr(self, "save_layers_to_disk") or self.save_layers_to_disk
        )

        # Start by loading the local adata
        check_and_load_layers(
            self, "local_adata", layers_to_load, shared_state, only_from_disk
        )

        # Create the refit adata
        reconstruct_refit_adata_without_layers(self)

        # Load the layers of the refit adata
        check_and_load_layers(
            self, "refit_adata", layers_to_load, shared_state, only_from_disk
        )

        # Apply the method
        shared_state = method(self, data_from_opener, shared_state, **method_parameters)

        # Remove all layers which must not be saved on disk
        for adata_name in ["local_adata", "refit_adata"]:
            adata = getattr(self, adata_name)
            if adata is None:
                continue
            if only_from_disk:
                layers_to_save_on_disk_adata: list | None = list(adata.layers.keys())
            else:
                layers_to_save_on_disk_adata = self.layers_to_save_on_disk[adata_name]
                assert layers_to_save_on_disk_adata is not None
                for layer in layers_to_save_on_disk_adata:
                    if layer not in adata.layers.keys():
                        print("Warning: layer not in adata: ", layer)
            assert layers_to_save_on_disk_adata is not None
            remove_layers(
                adata=adata,
                layers_to_save_on_disk=layers_to_save_on_disk_adata,
                refit=adata_name == "refit_adata",
            )

        # Reset the layers_to_save_on_disk attribute
        try:
            self.layers_to_save_on_disk = global_layers_to_save_on_disk
        except NameError:
            del self.layers_to_save_on_disk

        return shared_state

    return method_inner
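
A hedged sketch of a decorated local step and the two injected keyword arguments (the class and method names are hypothetical):

from fedpydeseq2.core.utils.layers.reconstruct_adatas_decorator import (
    reconstruct_adatas,
)


class MyLocalStep:
    """Hypothetical local step with local_adata and refit_adata attributes."""

    @reconstruct_adatas
    def local_step(self, data_from_opener, shared_state, **method_parameters):
        # Layers listed in layers_to_load are available here; on exit, layers
        # not listed in layers_to_save_on_disk are removed before the state
        # is saved.
        normed = self.local_adata.layers["normed_counts"]
        return {"mean_normed_counts": normed.mean(axis=0)}


# Callers may restrict what is reconstructed or persisted, e.g.:
# step.local_step(
#     data_from_opener,
#     shared_state,
#     layers_to_load={"local_adata": ["normed_counts"], "refit_adata": None},
#     layers_to_save_on_disk={"local_adata": ["cooks"], "refit_adata": None},
# )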

reconstruct_refit_adata_without_layers(self)

Reconstruct the refit adata without the layers.

This function reconstructs the refit adata without the layers. It is used to avoid loading the counts and the obsm unnecessarily into the refit_adata.

Parameters:

Name Type Description Default
self Any

The object containing the adata.

required
Source code in fedpydeseq2/core/utils/layers/reconstruct_adatas_decorator.py
def reconstruct_refit_adata_without_layers(self: Any):
    """Reconstruct the refit adata without the layers.

    This function reconstructs the refit adata without the layers.
    It is used to avoid loading the counts and the obsm unnecessarily into
    the refit_adata.

    Parameters
    ----------
    self : Any
        The object containing the adata.

    """
    if self.refit_adata is None:
        return
    if self.local_adata is not None and "replaced" in self.local_adata.varm.keys():
        set_basic_refit_adata(self)
    if self.local_adata is not None and "refitted" in self.local_adata.varm.keys():
        set_imputed_counts_refit_adata(self)

utils

get_available_layers(adata, shared_state, refit=False, all_layers_from_disk=False)

Get the available layers in the adata.

Parameters:

Name Type Description Default
adata Optional[AnnData]

The local adata.

required
shared_state dict

The shared state containing the Cook's dispersion values.

required
refit bool

Whether the adata is the refit_adata, in which case the Cook's distance and hat diagonals layers are not considered available.

False
all_layers_from_disk bool

Whether to get all layers from disk.

False

Returns:

Type Description
list[str]

List of available layers.

Source code in fedpydeseq2/core/utils/layers/utils.py
def get_available_layers(
    adata: ad.AnnData | None,
    shared_state: dict | None,
    refit: bool = False,
    all_layers_from_disk: bool = False,
) -> list[str]:
    """Get the available layers in the adata.

    Parameters
    ----------
    adata : Optional[ad.AnnData]
        The local adata.

    shared_state : dict
        The shared state containing the Cook's dispersion values.

    refit : bool
        Whether the adata is the refit_adata, in which case the Cook's
        distance and hat diagonals layers are not considered available.

    all_layers_from_disk : bool
        Whether to get all layers from disk.

    Returns
    -------
    list[str]
        List of available layers.

    """
    if adata is None:
        return []
    if all_layers_from_disk:
        return list(adata.layers.keys())
    available_layers = []
    if can_get_normed_counts(adata, raise_error=False):
        available_layers.append("normed_counts")
    if can_get_y_hat(adata, raise_error=False):
        available_layers.append("_y_hat")
    if can_get_mu_hat(adata, raise_error=False):
        available_layers.append("_mu_hat")
    if can_get_fit_lin_mu_hat(adata, raise_error=False):
        available_layers.append("_fit_lin_mu_hat")
    if can_get_sqerror_layer(adata, raise_error=False):
        available_layers.append("sqerror")
    if not refit and can_set_cooks_layer(
        adata, shared_state=shared_state, raise_error=False
    ):
        available_layers.append("cooks")
    if not refit and can_set_hat_diagonals_layer(
        adata, shared_state=shared_state, raise_error=False
    ):
        available_layers.append("_hat_diagonals")
    if can_set_mu_layer(
        adata, lfc_param_name="LFC", mu_param_name="_mu_LFC", raise_error=False
    ):
        available_layers.append("_mu_LFC")
    if can_set_mu_layer(
        adata,
        lfc_param_name="_mu_hat_LFC",
        mu_param_name="_irls_mu_hat",
        raise_error=False,
    ):
        available_layers.append("_irls_mu_hat")

    return available_layers
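
For instance, with all_layers_from_disk=True the availability check reduces to listing the layers already stored in the AnnData (toy data below):

import anndata as ad
import numpy as np

from fedpydeseq2.core.utils.layers.utils import get_available_layers

adata = ad.AnnData(X=np.ones((4, 3), dtype=float))
adata.layers["normed_counts"] = np.ones((4, 3))

print(get_available_layers(adata, shared_state=None, all_layers_from_disk=True))
# ['normed_counts']
print(get_available_layers(None, shared_state=None))  # []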

load_layers(adata, shared_state, layers_to_load, n_jobs=1, joblib_verbosity=0, joblib_backend='loky', batch_size=100)

Load the simple layers from the data_from_opener and the adata object.

This function loads the layers in the layers_to_load attribute in the adata object.

Parameters:

Name Type Description Default
adata AnnData

The AnnData object to load the layers into.

required
shared_state dict

The shared state containing the Cook's dispersion values.

required
layers_to_load list[str]

The list of layers to load.

required
n_jobs int

The number of jobs to use for parallel processing.

1
joblib_verbosity int

The verbosity level of joblib.

0
joblib_backend str

The joblib backend to use.

'loky'
batch_size int

The batch size for parallel processing.

100
Source code in fedpydeseq2/core/utils/layers/utils.py
def load_layers(
    adata: ad.AnnData,
    shared_state: dict | None,
    layers_to_load: list[str],
    n_jobs: int = 1,
    joblib_verbosity: int = 0,
    joblib_backend: str = "loky",
    batch_size: int = 100,
):
    """Load the simple layers from the data_from_opener and the adata object.

    This function loads the layers in the layers_to_load attribute in the
    adata object.

    Parameters
    ----------
    adata : ad.AnnData
        The AnnData object to load the layers into.

    shared_state : dict, optional
        The shared state containing the Cook's dispersion values.

    layers_to_load : list[str]
        The list of layers to load.

    n_jobs : int
        The number of jobs to use for parallel processing.

    joblib_verbosity : int
        The verbosity level of joblib.

    joblib_backend : str
        The joblib backend to use.

    batch_size : int
        The batch size for parallel processing.

    """
    # Check that all requested layers are known layers
    assert np.all(
        layer in AVAILABLE_LAYERS for layer in layers_to_load
    ), f"All layers in layers_to_load must be in {AVAILABLE_LAYERS}"

    if "normed_counts" in layers_to_load:
        set_normed_counts(adata=adata)
    if "_mu_LFC" in layers_to_load:
        set_mu_layer(
            local_adata=adata,
            lfc_param_name="LFC",
            mu_param_name="_mu_LFC",
            n_jobs=n_jobs,
            joblib_verbosity=joblib_verbosity,
            joblib_backend=joblib_backend,
            batch_size=batch_size,
        )
    if "_irls_mu_hat" in layers_to_load:
        set_mu_layer(
            local_adata=adata,
            lfc_param_name="_mu_hat_LFC",
            mu_param_name="_irls_mu_hat",
            n_jobs=n_jobs,
            joblib_verbosity=joblib_verbosity,
            joblib_backend=joblib_backend,
            batch_size=batch_size,
        )
    if "sqerror" in layers_to_load:
        set_sqerror_layer(adata)
    if "_y_hat" in layers_to_load:
        set_y_hat(adata)
    if "_fit_lin_mu_hat" in layers_to_load:
        set_fit_lin_mu_hat(adata)
    if "_mu_hat" in layers_to_load:
        set_mu_hat_layer(adata)
    if "_hat_diagonals" in layers_to_load:
        set_hat_diagonals_layer(adata=adata, shared_state=shared_state)
    if "cooks" in layers_to_load:
        set_cooks_layer(adata=adata, shared_state=shared_state)

remove_layers(adata, layers_to_save_on_disk, refit=False)

Remove the simple layers from the adata object.

This function removes the simple layers from the adata object. The layers_to_save_on_disk parameter can be used to specify which layers to save in the local state. If layers_to_save_on_disk is None, no layers are saved.

This function also adds all present layers to the _available_layers field in the adata object. This field is used to keep track of the layers that are present in the adata object.

Parameters:

Name Type Description Default
adata AnnData

The AnnData object to remove the layers from.

required
refit bool

Whether the adata object is the refit_adata object.

False
layers_to_save_on_disk list[str]

The list of layers to save. If None, no layers are saved.

required
Source code in fedpydeseq2/core/utils/layers/utils.py
def remove_layers(
    adata: ad.AnnData,
    layers_to_save_on_disk: list[str],
    refit: bool = False,
):
    """Remove the simple layers from the adata object.

    This function removes the simple layers from the adata object. The
    layers_to_save_on_disk parameter can be used to specify which layers to
    save in the local state. If layers_to_save_on_disk is None, no layers
    are saved.

    This function also adds all present layers to the _available_layers field in the
    adata object. This field is used to keep track of the layers that are present in
    the adata object.

    Parameters
    ----------
    adata : ad.AnnData
        The AnnData object to remove the layers from.

    refit : bool
        Whether the adata object is the refit_adata object.

    layers_to_save_on_disk : list[str]
        The list of layers to save. If None, no layers are saved.

    """
    adata.X = None
    if refit:
        adata.obsm = None

    layer_names = list(adata.layers.keys()).copy()
    for layer_name in layer_names:
        if layer_name in layers_to_save_on_disk:
            continue
        del adata.layers[layer_name]
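
A minimal example: X is cleared and only the layers listed in layers_to_save_on_disk survive (toy data):

import anndata as ad
import numpy as np

from fedpydeseq2.core.utils.layers.utils import remove_layers

adata = ad.AnnData(X=np.ones((4, 3), dtype=float))
adata.layers["normed_counts"] = np.ones((4, 3))
adata.layers["cooks"] = np.zeros((4, 3))

remove_layers(adata, layers_to_save_on_disk=["cooks"])
print(list(adata.layers.keys()))  # ['cooks']
print(adata.X)  # None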

logging

logging_decorators

Module containing decorators to log the inputs and outputs of a method.

All logging is controlled through a logging configuration file, which is either the file pointed to by the log_config_path attribute of the class or the default_config.ini file in the same directory as this module.

get_method_logger(self, method)

Get the method logger from a configuration file.

If the class instance has a log_config_path attribute, the logger is configured with the file at this path.

Parameters:

Name Type Description Default
self Any

The class instance

required
method Callable

The class method.

required

Returns:

Type Description
Logger

The logger instance.

Source code in fedpydeseq2/core/utils/logging/logging_decorators.py
def get_method_logger(self: Any, method: Callable) -> logging.Logger:
    """
    Get the method logger from a configuration file.

    If the class instance has a log_config_path attribute,
    the logger is configured with the file at this path.

    Parameters
    ----------
    self: Any
        The class instance
    method: Callable
        The class method.

    Returns
    -------
    logging.Logger
        The logger instance.
    """
    if hasattr(self, "log_config_path"):
        log_config_path = pathlib.Path(self.log_config_path)
    else:
        log_config_path = pathlib.Path(__file__).parent / "default_config.ini"
    logging.config.fileConfig(log_config_path, disable_existing_loggers=False)
    logger = logging.getLogger(method.__name__)
    return logger

log_remote(method)

Decorate a remote method to log the inputs and outputs.

This decorator logs the shared state keys with the info level.

Parameters:

Name Type Description Default
method Callable

The method to decorate. This method is expected to have the following signature: method(self, shared_states: Optional[list], **method_parameters).

required

Returns:

Type Description
Callable

The decorated method, which logs the shared state keys with the info level.

Source code in fedpydeseq2/core/utils/logging/logging_decorators.py
def log_remote(method: Callable):
    """
    Decorate a remote method to log the inputs and outputs.

    This decorator logs the shared state keys with the info level.

    Parameters
    ----------
    method : Callable
        The method to decorate. This method is expected to have the following signature:
        method(self, shared_states: Optional[list], **method_parameters).

    Returns
    -------
    Callable
        The decorated method, which logs the shared state keys with the info level.

    """

    @wraps(method)
    def remote_method_inner(
        self,
        shared_states: list | None,
        **method_parameters,
    ):
        logger = get_method_logger(self, method)
        if shared_states is not None:
            shared_state = shared_states[0]
            if shared_state is not None:
                logger.info(
                    f"First input shared state keys : {list(shared_state.keys())}"
                )
            else:
                logger.info("First input shared state is None.")
        else:
            logger.info("No input shared states.")

        shared_state = method(self, shared_states, **method_parameters)

        if shared_state is not None:
            logger.info(f"Output shared state keys : {list(shared_state.keys())}")
        else:
            logger.info("No output shared state.")

        return shared_state

    return remote_method_inner
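
A hypothetical aggregation method instrumented with log_remote (the class and method names are illustrative):

from fedpydeseq2.core.utils.logging.logging_decorators import log_remote


class MyAggregation:
    """Hypothetical aggregation step."""

    @log_remote
    def aggregate(self, shared_states, **method_parameters):
        return {"total": sum(state["value"] for state in shared_states)}


MyAggregation().aggregate([{"value": 1}, {"value": 2}])
# INFO: First input shared state keys : ['value']
# INFO: Output shared state keys : ['total']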

log_remote_data(method)

Decorate a remote_data method to log the inputs and outputs.

This decorator logs the shared state keys with the info level, and the different layers of the local_adata and refit_adata with the debug level.

This is done before and after the method call.

Parameters:

Name Type Description Default
method Callable

The method to decorate. This method is expected to have the following signature: method(self, data_from_opener: ad.AnnData, shared_state: Any = None, **method_parameters).

required

Returns:

Type Description
Callable

The decorated method, which logs the shared state keys with the info level and the different layers of the local_adata and refit_adata with the debug level.

Source code in fedpydeseq2/core/utils/logging/logging_decorators.py
def log_remote_data(method: Callable):
    """
    Decorate a remote_data method to log the inputs and outputs.

    This decorator logs the shared state keys with the info level,
    and the different layers of the local_adata and refit_adata with the debug level.

    This is done before and after the method call.

    Parameters
    ----------
    method : Callable
        The method to decorate. This method is expected to have the following signature:
        method(self, data_from_opener: ad.AnnData,
        shared_state: Any = None, **method_parameters).

    Returns
    -------
    Callable
        The decorated method, which logs the shared state keys with the info level
        and the different layers of the local_adata and refit_adata with the debug
        level.
    """

    @wraps(method)
    def remote_method_inner(
        self,
        data_from_opener: ad.AnnData,
        shared_state: Any = None,
        **method_parameters,
    ):
        logger = get_method_logger(self, method)
        logger.info("---- Before running the method ----")
        log_shared_state_adatas(self, method, shared_state)

        shared_state = method(self, data_from_opener, shared_state, **method_parameters)

        logger.info("---- After method ----")
        log_shared_state_adatas(self, method, shared_state)
        return shared_state

    return remote_method_inner

log_save_local_state(method)

Decorate a method to log the size of the local state saved.

This function is intended to decorate the save_local_state method of a class.

It logs the size of the local state saved in the local state path, in MB. This is logged as an info message.

Parameters:

Name Type Description Default
method Callable

The method to decorate. This method is expected to have the following signature: method(self, path: pathlib.Path).

required

Returns:

Type Description
Callable

The decorated method, which logs the size of the local state saved.

Source code in fedpydeseq2/core/utils/logging/logging_decorators.py
def log_save_local_state(method: Callable):
    """
    Decorate a method to log the size of the local state saved.

    This function is intended to decorate the save_local_state method of a class.

    It logs the size of the local state saved in the local state path, in MB.
    This is logged as an info message.

    Parameters
    ----------
    method : Callable
        The method to decorate. This method is expected to have the following signature:
        method(self, path: pathlib.Path).

    Returns
    -------
    Callable
        The decorated method, which logs the size of the local state saved.

    """

    @wraps(method)
    def remote_method_inner(
        self,
        path: pathlib.Path,
    ):
        logger = get_method_logger(self, method)

        output = method(self, path)

        logger.info(
            f"Size of local state saved : "
            f"{os.path.getsize(path) / 1024 / 1024}"
            " MB"
        )

        return output

    return remote_method_inner

log_shared_state_adatas(self, method, shared_state)

Log the information of the local step.

Precisely, log the shared state keys (info), and the different layers of the local_adata and refit_adata (debug).

Parameters:

Name Type Description Default
self Any

The class instance

required
method Callable

The class method.

required
shared_state Optional[dict]

The shared state dictionary, whose keys we log with the info level.

required
Source code in fedpydeseq2/core/utils/logging/logging_decorators.py
def log_shared_state_adatas(self: Any, method: Callable, shared_state: dict | None):
    """
    Log the information of the local step.

    Precisely, log the shared state keys (info),
    and the different layers of the local_adata and refit_adata (debug).

    Parameters
    ----------
    self : Any
        The class instance
    method : Callable
        The class method.
    shared_state : Optional[dict]
        The shared state dictionary, whose keys we log with the info level.

    """
    logger = get_method_logger(self, method)

    if shared_state is not None:
        logger.info(f"Shared state keys : {list(shared_state.keys())}")
    else:
        logger.info("No shared state")

    for adata_name in ["local_adata", "refit_adata"]:
        if hasattr(self, adata_name) and getattr(self, adata_name) is not None:
            adata = getattr(self, adata_name)
            logger.debug(f"{adata_name} layers : {list(adata.layers.keys())}")
            if "_available_layers" in self.local_adata.uns:
                available_layers = self.local_adata.uns["_available_layers"]
                logger.debug(f"{adata_name} available layers : {available_layers}")
            logger.debug(f"{adata_name} uns keys : {list(adata.uns.keys())}")
            logger.debug(f"{adata_name} varm keys : {list(adata.varm.keys())}")
            logger.debug(f"{adata_name} obsm keys : {list(adata.obsm.keys())}")

mle

batch_mle_grad(counts, design, mu, alpha)

Estimate the local gradients wrt dispersions on a batch of genes.

Returns both the gradient of the negative likelihood, and two matrices used to compute the gradient of the Cox-Reid adjustment.

Parameters:

Name Type Description Default
counts ndarray

Raw counts for a set of genes (n_samples x n_genes).

required
design ndarray

Design matrix (n_samples x n_params).

required
mu ndarray

Mean estimation for the NB model (n_samples x n_genes).

required
alpha ndarray

Initial dispersion estimates (n_genes).

required

Returns:

Name Type Description
grad ndarray

Gradient of the negative log likelihood of the observations counts following NB(\mu, \alpha) (n_genes).

M1 ndarray

First summand for the gradient of the CR adjustment (n_genes x n_params x n_params).

M2 ndarray

Second summand for the gradient of the CR adjustment (n_genes x n_params x n_params).

Source code in fedpydeseq2/core/utils/mle.py
def batch_mle_grad(
    counts: np.ndarray, design: np.ndarray, mu: np.ndarray, alpha: np.ndarray
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    r"""Estimate the local gradients wrt dispersions on a batch of genes.

    Returns both the gradient of the negative likelihood, and two matrices used to
    compute the gradient of the Cox-Reid adjustment.


    Parameters
    ----------
    counts : ndarray
        Raw counts for a set of genes (n_samples x n_genes).

    design : ndarray
        Design matrix (n_samples x n_params).

    mu : ndarray
        Mean estimation for the NB model (n_samples x n_genes).

    alpha : ndarray
        Initial dispersion estimates (n_genes).

    Returns
    -------
    grad : ndarray
        Gradient of the negative log likelihood of the observations counts following
        :math:`NB(\\mu, \\alpha)` (n_genes).

    M1 : ndarray
        First summand for the gradient of the CR adjustment
        (n_genes x n_params x n_params).

    M2 : ndarray
        Second summand for the gradient of the CR adjustment
        (n_genes x n_params x n_params).
    """
    grad = alpha * vec_nb_nll_grad(
        counts,
        mu,
        alpha,
    )  # Need to multiply by alpha to get the gradient wrt log_alpha

    W = mu / (1 + mu * alpha[None, :])

    dW = -(W**2)
    M1 = (design.T[:, :, None] * W).transpose(2, 0, 1) @ design[None, :, :]
    M2 = (design.T[:, :, None] * dW).transpose(2, 0, 1) @ design[None, :, :]

    return grad, M1, M2
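
A toy call showing the expected shapes (the values are illustrative only):

import numpy as np

from fedpydeseq2.core.utils.mle import batch_mle_grad

rng = np.random.default_rng(0)
n_samples, n_genes, n_params = 6, 4, 2

counts = rng.poisson(10.0, size=(n_samples, n_genes)).astype(float)
design = np.hstack([np.ones((n_samples, 1)), rng.normal(size=(n_samples, 1))])
mu = np.full((n_samples, n_genes), 10.0)
alpha = np.full(n_genes, 0.1)

grad, M1, M2 = batch_mle_grad(counts, design, mu, alpha)
assert grad.shape == (n_genes,)
assert M1.shape == M2.shape == (n_genes, n_params, n_params)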

batch_mle_update(log_alpha, global_CR_summand_1, global_CR_summand_2, global_ll_grad, lr, alpha_hat=None, prior_disp_var=None, prior_reg=False)

Perform a global dispersions update on a batch of genes.

Parameters:

Name Type Description Default
log_alpha ndarray

Current global log dispersions (n_genes).

required
global_CR_summand_1 ndarray

Global summand 1 for the CR adjustment (n_genes x n_params x n_params).

required
global_CR_summand_2 ndarray

Global summand 2 for the CR adjustment (n_genes x n_params x n_params).

required
global_ll_grad ndarray

Global gradient of the negative log likelihood (n_genes).

required
lr float

Learning rate.

required
alpha_hat ndarray

Reference dispersions (for MAP estimation, n_genes).

None
prior_disp_var float

Prior dispersion variance.

None
prior_reg bool

Whether to use prior regularization for MAP estimation (default: False).

False

Returns:

Type Description
ndarray

Updated global log dispersions (n_genes).

Source code in fedpydeseq2/core/utils/mle.py
def batch_mle_update(
    log_alpha: np.ndarray,
    global_CR_summand_1: np.ndarray,
    global_CR_summand_2: np.ndarray,
    global_ll_grad: np.ndarray,
    lr: float,
    alpha_hat: np.ndarray | None = None,
    prior_disp_var: float | None = None,
    prior_reg: bool = False,
):
    """Perform a global dispersions update on a batch of genes.

    Parameters
    ----------
    log_alpha : ndarray
        Current global log dispersions (n_genes).

    global_CR_summand_1 : ndarray
        Global summand 1 for the CR adjustment (n_genes x n_params x n_params).

    global_CR_summand_2 : ndarray
        Global summand 2 for the CR adjustment (n_genes x n_params x n_params).

    global_ll_grad : ndarray
        Global gradient of the negative log likelihood (n_genes).

    lr : float
        Learning rate.

    alpha_hat : ndarray
        Reference dispersions (for MAP estimation, n_genes).

    prior_disp_var : float
        Prior dispersion variance.

    prior_reg : bool
        Whether to use prior regularization for MAP estimation (default: ``False``).

    Returns
    -------
    ndarray
        Updated global log dispersions (n_genes).

    """
    # Add prior regularization, if required
    if prior_reg:
        global_ll_grad += (log_alpha - np.log(alpha_hat)) / prior_disp_var

    # Compute CR reg grad (not separable, cannot be computed locally)
    global_CR_grad = np.array(
        0.5
        * (np.linalg.inv(global_CR_summand_1) * global_CR_summand_2).sum(1).sum(1)
        * np.exp(log_alpha)
    )

    # Update dispersion
    global_log_alpha = log_alpha - lr * (global_ll_grad + global_CR_grad)

    return global_log_alpha
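
A hedged sketch of one federated gradient step on log-dispersions: in the real workflow grad, M1 and M2 are summed across centers before the update; a single center is shown here with toy data:

import numpy as np

from fedpydeseq2.core.utils.mle import batch_mle_grad, batch_mle_update

rng = np.random.default_rng(0)
n_samples, n_genes = 6, 4
counts = rng.poisson(10.0, size=(n_samples, n_genes)).astype(float)
design = np.hstack([np.ones((n_samples, 1)), rng.normal(size=(n_samples, 1))])
mu = np.full((n_samples, n_genes), 10.0)
alpha = np.full(n_genes, 0.1)

grad, M1, M2 = batch_mle_grad(counts, design, mu, alpha)
new_log_alpha = batch_mle_update(
    log_alpha=np.log(alpha),
    global_CR_summand_1=M1,  # summed across centers in the real workflow
    global_CR_summand_2=M2,
    global_ll_grad=grad,
    lr=0.1,
)
assert new_log_alpha.shape == (n_genes,)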

global_grid_cr_loss(nll, cr_grid)

Compute the global negative log likelihood on a grid.

Sums previously computed local negative log likelihoods and Cox-Reid adjustments.

Parameters:

Name Type Description Default
nll ndarray

Negative log likelihoods of size (n_genes x grid_length).

required
cr_grid ndarray

Summands for the Cox-Reid adjustment (n_genes x grid_length x n_params x n_params).

required

Returns:

Type Description
ndarray

Adjusted negative log likelihood (n_genes x grid_length).

Source code in fedpydeseq2/core/utils/mle.py
def global_grid_cr_loss(
    nll: np.ndarray,
    cr_grid: np.ndarray,
) -> np.ndarray:
    """Compute the global negative log likelihood on a grid.

    Sums previously computed local negative log likelihoods and Cox-Reid adjustments.

    Parameters
    ----------
    nll : ndarray
        Negative log likelihoods of size (n_genes x grid_length).

    cr_grid : ndarray
        Summands for the Cox-Reid adjustment
        (n_genes x grid_length x n_params x n_params).

    Returns
    -------
    ndarray
        Adjusted negative log likelihood (n_genes x grid_length).
    """
    if np.any(np.isnan(cr_grid)):
        n_genes, grid_length, n_params, _ = cr_grid.shape
        cr_grid = cr_grid.reshape(-1, n_params, n_params)
        mask_nan = np.any(np.isnan(cr_grid), axis=(1, 2))
        slogdet = np.zeros(n_genes * grid_length, dtype=cr_grid.dtype)
        slogdet[mask_nan] = np.nan
        if np.any(~mask_nan):
            slogdet[~mask_nan] = np.linalg.slogdet(cr_grid[~mask_nan])[1]
        return nll + 0.5 * slogdet.reshape(n_genes, grid_length)
    else:
        return nll + 0.5 * np.linalg.slogdet(cr_grid)[1]

local_grid_summands(counts, design, mu, alpha_grid)

Compute local summands of the adjusted negative log likelihood on a grid.

Includes the Cox-Reid regularization.

Parameters:

Name Type Description Default
counts ndarray

Raw counts for a set of genes (n_samples x n_genes).

required
design ndarray

Design matrix (n_samples x n_params).

required
mu ndarray

Mean estimation for the NB model (n_samples x n_genes).

required
alpha_grid ndarray

Dispersion estimates (n_genes x grid_length).

required

Returns:

Name Type Description
nll ndarray

Negative log likelihoods of size (n_genes x grid_length).

cr_matrix ndarray

Summands for the Cox-Reid adjustment (n_genes x grid_length x n_params x n_params).

Source code in fedpydeseq2/core/utils/mle.py
def local_grid_summands(
    counts: np.ndarray,
    design: np.ndarray,
    mu: np.ndarray,
    alpha_grid: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Compute local summands of the adjusted negative log likelihood on a grid.

    Includes the Cox-Reid regularization.

    Parameters
    ----------
    counts : ndarray
        Raw counts for a set of genes (n_samples x n_genes).

    design : ndarray
        Design matrix (n_samples x n_params).

    mu : ndarray
        Mean estimation for the NB model (n_samples x n_genes).

    alpha_grid : ndarray
        Dispersion estimates (n_genes x grid_length).

    Returns
    -------
    nll : ndarray
        Negative log likelihoods of size (n_genes x grid_length).

    cr_matrix : ndarray
        Summands for the Cox-Reid adjustment
        (n_genes x grid_length x n_params x n_params).
    """
    # W is of size (n_samples x n_genes x grid_length)
    W = mu[:, :, None] / (1 + mu[:, :, None] * alpha_grid)
    # cr_matrix is of size (n_genes x grid_length x n_params x n_params)
    cr_matrix = (design.T[:, :, None, None] * W).transpose(2, 3, 0, 1) @ design[
        None, None, :, :
    ]
    # nll is of size (n_genes x grid_length)
    nll = grid_nb_nll(counts, mu, alpha_grid)

    return nll, cr_matrix
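
Together with global_grid_cr_loss, this enables a federated grid search over dispersions; a hedged two-center sketch with toy data:

import numpy as np

from fedpydeseq2.core.utils.mle import global_grid_cr_loss, local_grid_summands

rng = np.random.default_rng(1)
n_genes, grid_length = 3, 5
alpha_grid = np.tile(np.geomspace(1e-2, 1.0, grid_length), (n_genes, 1))

# Each center computes its local summands on its own samples ...
nll_parts, cr_parts = [], []
for n_local in (4, 6):
    counts = rng.poisson(10.0, size=(n_local, n_genes)).astype(float)
    design = np.hstack([np.ones((n_local, 1)), rng.normal(size=(n_local, 1))])
    mu = np.full((n_local, n_genes), 10.0)
    nll, cr_matrix = local_grid_summands(counts, design, mu, alpha_grid)
    nll_parts.append(nll)
    cr_parts.append(cr_matrix)

# ... and the aggregator sums them before applying the Cox-Reid adjustment.
loss = global_grid_cr_loss(sum(nll_parts), sum(cr_parts))
assert loss.shape == (n_genes, grid_length)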

single_mle_grad(counts, design, mu, alpha)

Estimate the local gradients of a negative binomial GLM wrt dispersions.

Returns both the gradient of the negative likelihood, and two matrices used to compute the gradient of the Cox-Reid adjustment.

Parameters:

Name Type Description Default
counts ndarray

Raw counts for a given gene (n_samples).

required
design ndarray

Design matrix (n_samples x n_params).

required
mu ndarray

Mean estimation for the NB model (n_samples).

required
alpha float

Initial dispersion estimate (1).

required

Returns:

Name Type Description
grad ndarray

Gradient of the negative log likelihood of the observations counts following NB(\mu, \alpha) (1).

M1 ndarray

First summand for the gradient of the CR adjustment (n_params x n_params).

M2 ndarray

Second summand for the gradient of the CR adjustment (n_params x n_params).

Source code in fedpydeseq2/core/utils/mle.py
def single_mle_grad(
    counts: np.ndarray, design: np.ndarray, mu: np.ndarray, alpha: float
) -> tuple[float, np.ndarray, np.ndarray]:
    r"""Estimate the local gradients of a negative binomial GLM wrt dispersions.

    Returns both the gradient of the negative likelihood, and two matrices used to
    compute the gradient of the Cox-Reid adjustment.


    Parameters
    ----------
    counts : ndarray
        Raw counts for a given gene (n_samples).

    design : ndarray
        Design matrix (n_samples x n_params).

    mu : ndarray
        Mean estimation for the NB model (n_samples).

    alpha : float
        Initial dispersion estimate (1).

    Returns
    -------
    grad : ndarray
        Gradient of the negative log likelihood of the observations counts following
        :math:`NB(\\mu, \\alpha)` (1).

    M1 : ndarray
        First summand for the gradient of the CR adjustment (n_params x n_params).

    M2 : ndarray
        Second summand for the gradient of the CR adjustment (n_params x n_params).
    """
    grad = alpha * dnb_nll(counts, mu, alpha)
    W = mu / (1 + mu * alpha)
    dW = -(W**2)
    M1 = (design.T * W) @ design
    M2 = (design.T * dW) @ design

    return grad, M1, M2

vec_loss(counts, design, mu, alpha, cr_reg=True, prior_reg=False, alpha_hat=None, prior_disp_var=None)

Compute the adjusted negative log likelihood of a batch of genes.

Includes Cox-Reid regularization and (optionally) prior regularization.

Parameters:

Name Type Description Default
counts ndarray

Raw counts for a set of genes (n_samples x n_genes).

required
design ndarray

Design matrix (n_samples x n_params).

required
mu ndarray

Mean estimation for the NB model (n_samples x n_genes).

required
alpha ndarray

Dispersion estimates (n_genes).

required
cr_reg bool

Whether to include Cox-Reid regularization (default: True).

True
prior_reg bool

Whether to include prior regularization (default: False).

False
alpha_hat ndarray

Reference dispersions (for MAP estimation, n_genes).

None
prior_disp_var float

Prior dispersion variance.

None

Returns:

Type Description
ndarray

Adjusted negative log likelihood (n_genes).

Source code in fedpydeseq2/core/utils/mle.py
def vec_loss(
    counts: np.ndarray,
    design: np.ndarray,
    mu: np.ndarray,
    alpha: np.ndarray,
    cr_reg: bool = True,
    prior_reg: bool = False,
    alpha_hat: np.ndarray | None = None,
    prior_disp_var: float | None = None,
) -> np.ndarray:
    """Compute the adjusted negative log likelihood of a batch of genes.

    Includes Cox-Reid regularization and (optionally) prior regularization.

    Parameters
    ----------
    counts : ndarray
        Raw counts for a set of genes (n_samples x n_genes).

    design : ndarray
        Design matrix (n_samples x n_params).

    mu : ndarray
        Mean estimation for the NB model (n_samples x n_genes).

    alpha : ndarray
        Dispersion estimates (n_genes).

    cr_reg : bool
        Whether to include Cox-Reid regularization (default: True).

    prior_reg : bool
        Whether to include prior regularization (default: False).

    alpha_hat : ndarray, optional
        Reference dispersions (for MAP estimation, n_genes).

    prior_disp_var : float, optional
        Prior dispersion variance.

    Returns
    -------
    ndarray
        Adjusted negative log likelihood (n_genes).
    """
    # adjusted negative log likelihood to be minimized
    reg = 0
    if cr_reg:
        W = mu / (1 + mu * alpha)
        reg += (
            0.5
            * np.linalg.slogdet((design.T[:, :, None] * W).transpose(2, 0, 1) @ design)[
                1
            ]
        )
    if prior_reg:
        if prior_disp_var is None:
            raise ValueError("Sigma_prior is required for prior regularization")
        reg += (np.log(alpha) - np.log(alpha_hat)) ** 2 / (2 * prior_disp_var)
    return nb_nll(counts, mu, alpha) + reg
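
A toy evaluation, with and without the MAP prior (the values are illustrative only):

import numpy as np

from fedpydeseq2.core.utils.mle import vec_loss

rng = np.random.default_rng(2)
n_samples, n_genes = 8, 3
counts = rng.poisson(10.0, size=(n_samples, n_genes)).astype(float)
design = np.hstack([np.ones((n_samples, 1)), rng.normal(size=(n_samples, 1))])
mu = np.full((n_samples, n_genes), 10.0)
alpha = np.full(n_genes, 0.1)

loss = vec_loss(counts, design, mu, alpha)
map_loss = vec_loss(
    counts,
    design,
    mu,
    alpha,
    prior_reg=True,
    alpha_hat=np.full(n_genes, 0.2),
    prior_disp_var=0.25,
)
assert loss.shape == map_loss.shape == (n_genes,)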

negative_binomial

Gradients and loss functions for the negative binomial distribution.

grid_nb_nll(counts, mu, alpha_grid, mask_nan=None)

Neg log-likelihood of a negative binomial, batched wrt genes on a grid.

Parameters:

Name Type Description Default
counts ndarray

Observations, n_samples x n_genes.

required
mu ndarray

Mean estimation for the NB model (n_samples x n_genes).

required
alpha_grid ndarray

Dispersions (n_genes x grid_length).

required
mask_nan ndarray

Mask for the values of the grid where mu should have taken values >> 1.

None

Returns:

Type Description
ndarray

Negative log likelihoods of size (n_genes x grid_length).

Source code in fedpydeseq2/core/utils/negative_binomial.py
def grid_nb_nll(
    counts: np.ndarray,
    mu: np.ndarray,
    alpha_grid: np.ndarray,
    mask_nan: np.ndarray | None = None,
) -> np.ndarray:
    r"""Neg log-likelihood of a negative binomial, batched wrt genes on a grid.

    Parameters
    ----------
    counts : ndarray
        Observations, n_samples x n_genes.

    mu : ndarray
        Mean estimation for the NB model (n_samples x n_genes).

    alpha_grid : ndarray
        Dispersions (n_genes x grid_length).

    mask_nan : ndarray
        Mask for the values of the grid where mu should have taken values >> 1.

    Returns
    -------
    ndarray
        Negative log likelihoods of size (n_genes x grid_length).
    """
    n = len(counts)
    alpha_neg1 = 1 / alpha_grid
    ndim_alpha = alpha_grid.ndim
    extra_dims_counts = tuple(range(2, 2 + ndim_alpha - 1))
    expanded_counts = np.expand_dims(counts, axis=extra_dims_counts)
    # In order to avoid infinities, we replace all large values of mu with 1,
    # and set the final quantity to NaN for the inputs where mu should have
    # taken values >> 1
    if mask_nan is not None:
        mu[mask_nan] = 1.0
    expanded_mu = np.expand_dims(mu, axis=extra_dims_counts)
    logbinom = (
        gammaln(expanded_counts + alpha_neg1[None, :])
        - gammaln(expanded_counts + 1)
        - gammaln(alpha_neg1[None, :])
    )

    nll = n * alpha_neg1 * np.log(alpha_grid) + (
        -logbinom
        + (expanded_counts + alpha_neg1) * np.log(alpha_neg1 + expanded_mu)
        - expanded_counts * np.log(expanded_mu)
    ).sum(0)
    if mask_nan is not None:
        nll[mask_nan.sum(0) > 0] = np.nan
    return nll

mu_grid_nb_nll(counts, mu_grid, alpha)

Compute the neg log-likelihood of a negative binomial.

This function is batched wrt genes on a mu grid.

Parameters:

Name Type Description Default
counts ndarray

Observations, (n_obs, batch_size).

required
mu_grid ndarray

Means of the distribution μ, (n_mu, batch_size, n_obs).

required
alpha ndarray

Dispersions of the distribution α, s.t. the variance is μ + αμ², of size (batch_size,).

required

Returns:

Type Description
ndarray

Negative log likelihoods of the observations counts following NB(μ, α), of size (n_mu, batch_size).

Notes

[1] https://en.wikipedia.org/wiki/Negative_binomial_distribution

Source code in fedpydeseq2/core/utils/negative_binomial.py
def mu_grid_nb_nll(
    counts: np.ndarray, mu_grid: np.ndarray, alpha: np.ndarray
) -> np.ndarray:
    r"""Compute the neg log-likelihood of a negative binomial.

    This function is *batched* wrt genes on a mu grid.

    Parameters
    ----------
    counts : ndarray
        Observations, (n_obs, batch_size).

    mu_grid : ndarray
        Means of the distribution :math:`\\mu`, (n_mu, batch_size, n_obs).

    alpha : ndarray
        Dispersions of the distribution :math:`\\alpha`,
        s.t. the variance is :math:`\\mu + \\alpha \\mu^2`,
        of size (batch_size,).

    Returns
    -------
    ndarray
        Negative log likelihoods of the observations counts
        following :math:`NB(\\mu, \\alpha)`, of size (n_mu, batch_size).

    Notes
    -----
    [1] https://en.wikipedia.org/wiki/Negative_binomial_distribution
    """
    n = len(counts)
    alpha_neg1 = 1 / alpha  # shape (batch_size,)
    logbinom = np.expand_dims(
        (
            gammaln(counts.T + alpha_neg1[:, None])
            - gammaln(counts.T + 1)
            - gammaln(alpha_neg1[:, None])
        ),
        axis=0,
    )  # Of size (1, batch_size, n_obs)
    first_term = np.expand_dims(
        n * alpha_neg1 * np.log(alpha), axis=0
    )  # Of size (1, batch_size)
    second_term = np.expand_dims(
        counts.T + np.expand_dims(alpha_neg1, axis=1), axis=0
    ) * np.log(
        np.expand_dims(alpha_neg1, axis=(0, 2)) + mu_grid
    )  # Of size (n_mu, batch_size, n_obs)
    third_term = -np.expand_dims(counts.T, axis=0) * np.log(
        mu_grid
    )  # Of size (n_mu, batch_size, n_obs)
    return first_term + (-logbinom + second_term + third_term).sum(axis=2)
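A shape-oriented sketch (assumed import path, toy data) scanning a grid of candidate means for a batch of genes and selecting the per-gene minimizer:

import numpy as np
from fedpydeseq2.core.utils.negative_binomial import mu_grid_nb_nll

rng = np.random.default_rng(0)
n_obs, batch_size, n_mu = 6, 3, 20
counts = rng.poisson(5.0, size=(n_obs, batch_size)).astype(float)
alpha = rng.uniform(0.1, 0.5, size=batch_size)

# One candidate mean per grid row, broadcast over genes and observations.
mu_candidates = np.linspace(1.0, 20.0, n_mu)
mu_grid = mu_candidates[:, None, None] * np.ones((1, batch_size, n_obs))

nll = mu_grid_nb_nll(counts, mu_grid, alpha)  # (n_mu, batch_size)
best_mu = mu_candidates[nll.argmin(axis=0)]   # (batch_size,)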

vec_nb_nll_grad(counts, mu, alpha)

Return the gradient of the negative log-likelihood of a negative binomial.

Vectorized version (wrt genes).

Parameters:

Name Type Description Default
counts ndarray

Observations, n_samples x n_genes.

required
mu ndarray

Mean of the distribution.

required
alpha ndarray

Dispersion of the distribution, s.t. the variance is μ + αμ².

required

Returns:

Type Description
ndarray

Gradient of the negative log likelihood of the observations counts following NB(μ, α).

Source code in fedpydeseq2/core/utils/negative_binomial.py
def vec_nb_nll_grad(
    counts: np.ndarray, mu: np.ndarray, alpha: np.ndarray
) -> np.ndarray:
    r"""Return the gradient of the negative log-likelihood of a negative binomial.

    Vectorized version (wrt genes).

    Parameters
    ----------
    counts : ndarray
        Observations, n_samples x n_genes.

    mu : ndarray
        Mean of the distribution.

    alpha : ndarray
        Dispersion of the distribution, s.t. the variance is
        :math:`\\mu + \\alpha \\mu^2`.

    Returns
    -------
    ndarray
        Gradient of the negative log likelihood of the observations counts following
        :math:`NB(\\mu, \\alpha)`.
    """
    alpha_neg1 = 1 / alpha
    ll_part = alpha_neg1**2 * (
        polygamma(0, alpha_neg1[None, :])
        - polygamma(0, counts + alpha_neg1[None, :])
        + np.log(1 + mu * alpha[None, :])
        + (counts - mu) / (mu + alpha_neg1[None, :])
    ).sum(0)

    return -ll_part
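Because grid_nb_nll evaluates the same NLL whose gradient with respect to alpha is returned here, the two can be cross-checked with central finite differences (assumed import path, toy data):

import numpy as np
from fedpydeseq2.core.utils.negative_binomial import grid_nb_nll, vec_nb_nll_grad

rng = np.random.default_rng(0)
n_samples, n_genes = 10, 4
counts = rng.poisson(8.0, size=(n_samples, n_genes)).astype(float)
mu = np.full((n_samples, n_genes), 8.0)
alpha = rng.uniform(0.2, 0.8, size=n_genes)

grad = vec_nb_nll_grad(counts, mu, alpha)  # analytical gradient, (n_genes,)

# Central finite differences of the NLL via a two-point dispersion grid per gene.
eps = 1e-5
alpha_grid = np.stack([alpha - eps, alpha + eps], axis=1)  # (n_genes, 2)
nll = grid_nb_nll(counts, mu, alpha_grid)                  # (n_genes, 2)
fd_grad = (nll[:, 1] - nll[:, 0]) / (2 * eps)

print(np.allclose(grad, fd_grad, rtol=1e-5))  # expected: True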

pass_on_results

Module to implement the passing of the first shared state.

TODO: remove once all result saving has been factored out, if no longer needed.

AggPassOnResults

Mixin to pass on the first shared state.

Source code in fedpydeseq2/core/utils/pass_on_results.py
class AggPassOnResults:
    """Mixin to pass on the first shared state."""

    results: dict | None

    @remote
    @log_remote
    def pass_on_results(self, shared_states: list[dict]) -> dict:
        """Pass on the shared state.

        This method simply returns the first shared state.

        Parameters
        ----------
        shared_states : list
            List of shared states.

        Returns
        -------
        dict
            The first shared state.

        """
        results = shared_states[0]
        # This is an ugly way to save the results for the simulation mode.
        # In simulation mode, we will look at the results attribute of the class
        # to get the results.
        # In the real mode, we will download the last shared state.
        self.results = results
        return results

pass_on_results(shared_states)

Pass on the shared state.

This method simply returns the first shared state.

Parameters:

Name Type Description Default
shared_states list

List of shared states.

required

Returns:

Name Type Description
dict The first shared state.
Source code in fedpydeseq2/core/utils/pass_on_results.py
@remote
@log_remote
def pass_on_results(self, shared_states: list[dict]) -> dict:
    """Pass on the shared state.

    This method simply returns the first shared state.

    Parameters
    ----------
    shared_states : list
        List of shared states.

    Returns
    -------
    dict
        The first shared state.

    """
    results = shared_states[0]
    # This is an ugly way to save the results for the simulation mode.
    # In simulation mode, we will look at the results attribute of the class
    # to get the results.
    # In the real mode, we will download the last shared state.
    self.results = results
    return results

pipe_steps

aggregation_step(aggregation_method, train_data_nodes, aggregation_node, input_shared_states, round_idx, description='', clean_models=True, method_params=None)

Perform an aggregation step of the federated learning strategy.

Used as a wrapper to execute an aggregation method on the shared states sent by the organizations, on the aggregation node.

Parameters:

Name Type Description Default
aggregation_method Callable

Method to be executed on the shared states.

required
train_data_nodes list

List of TrainDataNode.

required
aggregation_node AggregationNode

Aggregation node.

required
input_shared_states list

List of shared states to be aggregated.

required
round_idx int

Round index.

required
description str

Description of the algorithm.

''
clean_models bool

Whether to clean the models after the computation.

True
method_params dict

Optional keyword arguments to be passed to the aggregation method.

None

Returns:

Name Type Description
SharedStateRef

A shared state containing the results of the aggregation.

round_idx int

Round index incremented by 1

Source code in fedpydeseq2/core/utils/pipe_steps.py
def aggregation_step(
    aggregation_method: Callable,
    train_data_nodes: list[TrainDataNode],
    aggregation_node: AggregationNode,
    input_shared_states: list[SharedStateRef],
    round_idx: int,
    description: str = "",
    clean_models: bool = True,
    method_params: dict | None = None,
) -> tuple[SharedStateRef, int]:
    """Perform an aggregation step of the federated learning strategy.

    Used as a wrapper to execute an aggregation method on the shared states sent
    by the organizations, on the aggregation node.

    Parameters
    ----------
    aggregation_method : Callable
        Method to be executed on the shared states.
    train_data_nodes : list
        List of TrainDataNode.
    aggregation_node : AggregationNode
        Aggregation node.
    input_shared_states : list
        List of shared states to be aggregated.
    round_idx : int
        Round index.
    description : str
        Description of the algorithm.
    clean_models : bool
        Whether to clean the models after the computation.
    method_params : dict, optional
        Optional keyword arguments to be passed to the aggregation method.

    Returns
    -------
    SharedStateRef
        A shared state containing the results of the aggregation.
    round_idx : int
        Round index incremented by 1
    """
    method_params = method_params or {}
    share_state = aggregation_node.update_states(
        aggregation_method(
            shared_states=input_shared_states,
            _algo_name=description,
            **method_params,
        ),
        round_idx=round_idx,
        authorized_ids={
            train_data_node.organization_id for train_data_node in train_data_nodes
        },
        clean_models=clean_models,
    )
    round_idx += 1
    return share_state, round_idx

local_step(local_method, train_data_nodes, output_local_states, round_idx, input_local_states=None, input_shared_state=None, aggregation_id=None, description='', clean_models=True, method_params=None)

Local step of the federated learning strategy.

Used as a wrapper to execute a local method on the data of each organization.

Parameters:

Name Type Description Default
local_method Callable

Method to be executed on the local data.

required
train_data_nodes list

List of TrainDataNode.

required
output_local_states dict

Dictionary of local states to be updated.

required
round_idx int

Round index.

required
input_local_states dict

Dictionary of local states to be used as input.

None
input_shared_state SharedStateRef

Shared state to be used as input.

None
aggregation_id str

Aggregation node id.

None
description str

Description of the algorithm.

''
clean_models bool

Whether to clean the models after the computation.

True
method_params dict

Optional keyword arguments to be passed to the local method.

None

Returns:

Name Type Description
output_local_states dict

Local states containing the results of the local method, to keep within the training nodes.

output_shared_states list

Shared states containing the results of the local method, to be sent to the aggregation node.

round_idx int

Round index incremented by 1

Source code in fedpydeseq2/core/utils/pipe_steps.py
def local_step(
    local_method: Callable,
    train_data_nodes: list[TrainDataNode],
    output_local_states: dict[str, LocalStateRef],
    round_idx: int,
    input_local_states: dict[str, LocalStateRef] | None = None,
    input_shared_state: SharedStateRef | None = None,
    aggregation_id: str | None = None,
    description: str = "",
    clean_models: bool = True,
    method_params: dict | None = None,
) -> tuple[dict[str, LocalStateRef], list[SharedStateRef], int]:
    """Local step of the federated learning strategy.

    Used as a wrapper to execute a local method on the data of each organization.

    Parameters
    ----------
    local_method : Callable
        Method to be executed on the local data.
    train_data_nodes : list
        List of TrainDataNode.
    output_local_states : dict
        Dictionary of local states to be updated.
    round_idx : int
        Round index.
    input_local_states : dict, optional
        Dictionary of local states to be used as input.
    input_shared_state : SharedStateRef, optional
        Shared state to be used as input.
    aggregation_id : str, optional
        Aggregation node id.
    description : str
        Description of the algorithm.
    clean_models : bool
        Whether to clean the models after the computation.
    method_params : dict, optional
        Optional keyword arguments to be passed to the local method.

    Returns
    -------
    output_local_states : dict
        Local states containing the results of the local method,
        to keep within the training nodes.
    output_shared_states : list
        Shared states containing the results of the local method,
        to be sent to the aggregation node.
    round_idx : int
        Round index incremented by 1
    """
    output_shared_states = []
    method_params = method_params or {}

    for node in train_data_nodes:
        next_local_state, next_shared_state = node.update_states(
            local_method(
                node.data_sample_keys,
                shared_state=input_shared_state,
                _algo_name=description,
                **method_params,
            ),
            local_state=(
                input_local_states[node.organization_id] if input_local_states else None
            ),
            round_idx=round_idx,
            authorized_ids={node.organization_id},
            aggregation_id=aggregation_id,
            clean_models=clean_models,
        )

        output_local_states[node.organization_id] = next_local_state
        output_shared_states.append(next_shared_state)

    round_idx += 1
    return output_local_states, output_shared_states, round_idx
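A schematic sketch of one federated round chaining both wrappers. Everything here besides the two helpers is assumed: train_data_nodes, aggregation_node, and a substrafl strategy instance exposing a remote-data method local_update and a remote method aggregate_updates (hypothetical names):

from fedpydeseq2.core.utils.pipe_steps import aggregation_step, local_step

local_states: dict = {}
round_idx = 0

# Local pass: run the (hypothetical) local method on each center;
# this yields one shared state per center.
local_states, shared_states, round_idx = local_step(
    local_method=strategy.local_update,
    train_data_nodes=train_data_nodes,
    output_local_states=local_states,
    round_idx=round_idx,
    aggregation_id=aggregation_node.organization_id,
    description="Compute local update",
)

# Aggregation pass: combine the per-center shared states on the aggregation node.
agg_shared_state, round_idx = aggregation_step(
    aggregation_method=strategy.aggregate_updates,
    train_data_nodes=train_data_nodes,
    aggregation_node=aggregation_node,
    input_shared_states=shared_states,
    round_idx=round_idx,
    description="Aggregate local updates",
)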

stat_utils

build_contrast(design_factors, design_columns, continuous_factors=None, contrast=None)

Check the validity of the contrast (if provided).

If not, build a default contrast, corresponding to the last column of the design matrix. A contrast should be a list of three strings, in the following format: ['variable_of_interest', 'tested_level', 'reference_level']. Names must correspond to the metadata passed to the FedCenters. E.g., ['condition', 'B', 'A'] will measure the LFC of 'condition B' compared to 'condition A'. For continuous variables, the last two strings are left empty, e.g. ['measurement', '', '']. If None, the last variable from the design matrix is chosen as the variable of interest, and the reference level is picked alphabetically.

Parameters:

Name Type Description Default
design_factors list

The design factors.

required
design_columns list

The names of the columns of the design matrices in the centers.

required
continuous_factors list

The continuous factors in the design, if any. (default: None).

None
contrast list

A list of three strings, in the following format: ['variable_of_interest', 'tested_level', 'reference_level']. (default: None).

None
Source code in fedpydeseq2/core/utils/stat_utils.py
def build_contrast(
    design_factors,
    design_columns,
    continuous_factors=None,
    contrast: list[str] | None = None,
) -> list[str]:
    """Check the validity of the contrast (if provided).

    If not, build a default
    contrast, corresponding to the last column of the design matrix.
    A contrast should be a list of three strings, in the following format:
    ``['variable_of_interest', 'tested_level', 'reference_level']``.
    Names must correspond to the metadata passed to the FedCenters.
    E.g., ``['condition', 'B', 'A']`` will measure the LFC of 'condition B'
    compared to 'condition A'.
    For continuous variables, the last two strings will be left empty, e.g.
    ``['measurement', '', '']``.
    If None, the last variable from the design matrix
    is chosen as the variable of interest, and the reference level is picked
    alphabetically.

    Parameters
    ----------
    design_factors : list
        The design factors.
    design_columns : list
        The names of the columns of the design matrices in the centers.
    continuous_factors : list, optional
        The continuous factors in the design, if any. (default: ``None``).
    contrast : list, optional
        A list of three strings, in the following format:
        ``['variable_of_interest', 'tested_level', 'reference_level']``.
        (default: ``None``).
    """
    if contrast is not None:  # Test contrast if provided
        if len(contrast) != 3:
            raise ValueError("The contrast should contain three strings.")
        if contrast[0] not in design_factors:
            raise KeyError(
                f"The contrast variable ('{contrast[0]}') should be one "
                f"of the design factors."
            )
        # TODO: Ideally, we should check that the levels are valid. This might leak
        # data from the centers, though.

    else:  # Build contrast if None
        factor = design_factors[-1]
        # Check whether this factor is categorical or continuous.
        if continuous_factors is not None and factor in continuous_factors:
            # The factor is continuous
            contrast = [factor, "", ""]
        else:
            # The factor is categorical
            factor_col = next(col for col in design_columns if col.startswith(factor))
            split_col = factor_col.split("_")
            contrast = [split_col[0], split_col[1], split_col[-1]]

    return contrast
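For instance (assumed import path), with pydeseq2-style design column names of the form 'factor_tested_vs_ref':

from fedpydeseq2.core.utils.stat_utils import build_contrast

# Default contrast: last design factor, levels read off the design column name.
build_contrast(
    design_factors=["condition"],
    design_columns=["intercept", "condition_B_vs_A"],
)
# -> ["condition", "B", "A"]

# Continuous variable of interest: tested and reference levels are left empty.
build_contrast(
    design_factors=["condition", "measurement"],
    design_columns=["intercept", "condition_B_vs_A", "measurement"],
    continuous_factors=["measurement"],
)
# -> ["measurement", "", ""]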

build_contrast_vector(contrast, LFC_columns)

Build a vector corresponding to the desired contrast.

Allows to test any pair of levels without refitting LFCs.

Parameters:

Name Type Description Default
contrast list

A list of three strings, in the following format: ['variable_of_interest', 'tested_level', 'reference_level'].

required
LFC_columns list

The names of the columns of the LFC matrices in the centers.

required

Returns:

Name Type Description
contrast_vector ndarray

The contrast vector, containing multipliers to apply to the LFCs.

contrast_idx (int, optional)

The index of the tested contrast in the LFC matrix.

Source code in fedpydeseq2/core/utils/stat_utils.py
def build_contrast_vector(contrast, LFC_columns) -> tuple[np.ndarray, int | None]:
    """
    Build a vector corresponding to the desired contrast.

    Allows to test any pair of levels without refitting LFCs.

    Parameters
    ----------
    contrast : list
        A list of three strings, in the following format:
        ``['variable_of_interest', 'tested_level', 'reference_level']``.
    LFC_columns : list
        The names of the columns of the LFC matrices in the centers.

    Returns
    -------
    contrast_vector : ndarray
        The contrast vector, containing multipliers to apply to the LFCs.
    contrast_idx : int, optional
        The index of the tested contrast in the LFC matrix.
    """
    factor = contrast[0]
    alternative = contrast[1]
    ref = contrast[2]
    if ref == alternative == "":
        # "factor" is a continuous variable
        contrast_level = factor
    else:
        contrast_level = f"{factor}_{alternative}_vs_{ref}"

    contrast_vector = np.zeros(len(LFC_columns))
    if contrast_level in LFC_columns:
        contrast_idx = LFC_columns.get_loc(contrast_level)
        contrast_vector[contrast_idx] = 1
    elif f"{factor}_{ref}_vs_{alternative}" in LFC_columns:
        # Reference and alternative are inverted
        contrast_idx = LFC_columns.get_loc(f"{factor}_{ref}_vs_{alternative}")
        contrast_vector[contrast_idx] = -1
    else:
        # Need to change reference
        # Get any column corresponding to the desired factor and extract old ref
        old_ref = next(col for col in LFC_columns if col.startswith(factor)).split(
            "_vs_"
        )[-1]
        new_alternative_idx = LFC_columns.get_loc(
            f"{factor}_{alternative}_vs_{old_ref}"
        )
        new_ref_idx = LFC_columns.get_loc(f"{factor}_{ref}_vs_{old_ref}")
        contrast_vector[new_alternative_idx] = 1
        contrast_vector[new_ref_idx] = -1
        # In that case there is no contrast index
        contrast_idx = None

    return contrast_vector, contrast_idx
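Note that LFC_columns must support .get_loc, i.e. it should be a pandas Index. A small sketch (assumed import path):

import pandas as pd
from fedpydeseq2.core.utils.stat_utils import build_contrast_vector

lfc_columns = pd.Index(["intercept", "condition_B_vs_A"])

# Direct match: +1 on the tested coefficient.
vec, idx = build_contrast_vector(["condition", "B", "A"], lfc_columns)
# vec -> array([0., 1.]), idx -> 1

# Inverted levels: the same column is reused with a -1 multiplier.
vec, idx = build_contrast_vector(["condition", "A", "B"], lfc_columns)
# vec -> array([0., -1.]), idx -> 1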

wald_test(M, lfc, ridge_factor, contrast_vector, lfc_null, alt_hypothesis)

Run Wald test for a single gene.

Computes Wald statistics, standard error and p-values from dispersion and LFC estimates.

Parameters:

Name Type Description Default
M ndarray

Central parameter in the covariance matrix estimator.

required
lfc ndarray

Log-fold change estimate (in natural log scale).

required
ridge_factor ndarray

Regularization factors.

required
contrast_vector ndarray

Vector encoding the contrast that is being tested.

required
lfc_null float

The log fold change (in natural log scale) under the null hypothesis.

required
alt_hypothesis str

The alternative hypothesis for computing Wald p-values.

required

Returns:

Name Type Description
wald_p_value float

Estimated p-value.

wald_statistic float

Wald statistic.

wald_se float

Standard error of the Wald statistic.

Source code in fedpydeseq2/core/utils/stat_utils.py
def wald_test(
    M: np.ndarray,
    lfc: np.ndarray,
    ridge_factor: np.ndarray | None,
    contrast_vector: np.ndarray,
    lfc_null: float,
    alt_hypothesis: Literal["greaterAbs", "lessAbs", "greater", "less"] | None,
) -> tuple[float, float, float]:
    """Run Wald test for a single gene.

    Computes Wald statistics, standard error and p-values from
    dispersion and LFC estimates.

    Parameters
    ----------
    M : ndarray
        Central parameter in the covariance matrix estimator.

    lfc : ndarray
        Log-fold change estimate (in natural log scale).

    ridge_factor : ndarray, optional
        Regularization factors.

    contrast_vector : ndarray
        Vector encoding the contrast that is being tested.

    lfc_null : float
        The log fold change (in natural log scale) under the null hypothesis.

    alt_hypothesis : str, optional
        The alternative hypothesis for computing Wald p-values.

    Returns
    -------
    wald_p_value : float
        Estimated p-value.

    wald_statistic : float
        Wald statistic.

    wald_se : float
        Standard error of the Wald statistic.
    """
    # Build covariance matrix estimator

    if ridge_factor is None:
        ridge_factor = np.diag(np.repeat(1e-6, M.shape[0]))
    H = np.linalg.inv(M + ridge_factor)
    Hc = H @ contrast_vector
    # Evaluate standard error and Wald statistic
    wald_se: float = np.sqrt(Hc.T @ M @ Hc)

    def greater(lfc_null):
        stat = contrast_vector @ np.fmax((lfc - lfc_null) / wald_se, 0)
        pval = norm.sf(stat)
        return stat, pval

    def less(lfc_null):
        stat = contrast_vector @ np.fmin((lfc - lfc_null) / wald_se, 0)
        pval = norm.sf(np.abs(stat))
        return stat, pval

    def greater_abs(lfc_null):
        stat = contrast_vector @ (
            np.sign(lfc) * np.fmax((np.abs(lfc) - lfc_null) / wald_se, 0)
        )
        pval = 2 * norm.sf(np.abs(stat))  # Only case where the test is two-tailed
        return stat, pval

    def less_abs(lfc_null):
        stat_above, pval_above = greater(-abs(lfc_null))
        stat_below, pval_below = less(abs(lfc_null))
        return min(stat_above, stat_below, key=abs), max(pval_above, pval_below)

    wald_statistic: float
    wald_p_value: float
    if alt_hypothesis:
        # Dispatch on the alternative hypothesis; only the selected test is run.
        wald_statistic, wald_p_value = {
            "greaterAbs": greater_abs,
            "lessAbs": less_abs,
            "greater": greater,
            "less": less,
        }[alt_hypothesis](lfc_null)
    else:
        wald_statistic = contrast_vector @ (lfc - lfc_null) / wald_se
        wald_p_value = 2 * norm.sf(np.abs(wald_statistic))

    return wald_p_value, wald_statistic, wald_se
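A toy example of the default two-tailed test (assumed import path; M stands in for the per-gene matrix used in the covariance estimator):

import numpy as np
from fedpydeseq2.core.utils.stat_utils import wald_test

M = np.array([[4.0, 1.0], [1.0, 3.0]])  # symmetric positive definite, (2, 2)
lfc = np.array([0.5, 1.2])              # natural-log LFC estimates
contrast_vector = np.array([0.0, 1.0])  # test the second coefficient

p_value, statistic, se = wald_test(
    M=M,
    lfc=lfc,
    ridge_factor=None,       # defaults to a small diagonal ridge
    contrast_vector=contrast_vector,
    lfc_null=0.0,
    alt_hypothesis=None,     # two-tailed test against lfc_null
)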