Source code for plismbench.metrics.base

"""Module for base metric object."""

from abc import ABC, abstractmethod

import numpy as np
from loguru import logger


try:
    import cupy as cp
except ImportError as error:
    logger.error(
        f"cupy is not installed. Please run `make install-cupy`.\nError: {error}."
    )


class BasePlismMetric(ABC):
    """Base class for metrics.

    Subclasses must implement :meth:`compute_metric`. Inheriting from
    :class:`abc.ABC` makes ``@abstractmethod`` enforced, so neither this class
    nor a subclass missing ``compute_metric`` can be instantiated.

    Parameters
    ----------
    device : Literal["cpu", "gpu"]
        Device to use for computation. ``"gpu"`` selects CuPy as the array
        backend; any other value selects NumPy.
    use_mixed_precision : bool, default=True
        Stored on the instance for subclasses to consult; this base class
        does not use it itself.

    Attributes
    ----------
    device : str
        Device passed at construction.
    ncp : module
        Array module used for computation (``cupy`` when ``device == "gpu"``,
        ``numpy`` otherwise).
    use_mixed_precision : bool
        Mixed-precision flag passed at construction.

    Raises
    ------
    ImportError
        If ``device == "gpu"`` but CuPy could not be imported.
    """

    def __init__(self, device: str, use_mixed_precision: bool = True):
        self.device = device
        if device == "gpu":
            # ``cp`` is only bound at module level when the guarded
            # ``import cupy`` succeeded. Look it up explicitly so a missing
            # CuPy install yields an actionable error instead of a NameError.
            cupy_module = globals().get("cp")
            if cupy_module is None:
                raise ImportError(
                    "device='gpu' requires cupy, which is not installed. "
                    "Please run `make install-cupy`."
                )
            self.ncp = cupy_module
        else:
            self.ncp = np
        self.use_mixed_precision = use_mixed_precision

    @abstractmethod
    def compute_metric(self, matrix_a: np.ndarray, matrix_b: np.ndarray):
        """Compute metric between feature matrices A and B.

        Parameters
        ----------
        matrix_a : np.ndarray
            First feature matrix.
        matrix_b : np.ndarray
            Second feature matrix.

        Returns
        -------
        object
            Metric value; its exact type is defined by each subclass.

        Raises
        ------
        NotImplementedError
            Always, if the base implementation is reached (e.g. via ``super()``).
        """
        raise NotImplementedError