Source code for plismbench.utils.viz

"""Visualization of robustness results across different extractors."""

from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


sns.set_style("darkgrid")
pd.set_option("future.no_silent_downcasting", True)

# Please leave those 2 dictionnaries as is.
EXTRACTOR_LABELS_DICT = {
    "conch": "CONCH",
    "gpfm": "GPFM",
    "hibou_vit_base": "Hibou Base",
    "hibou_vit_large": "Hibou Large",
    "h0_mini": "H0-Mini",
    "hoptimus0": "H-Optimus-0",
    "kaiko_vit_base_8": "Kaiko ViT-B/8",
    "kaiko_vit_large_14": "Kaiko ViT-L/14",
    "phikon": "Phikon",
    "phikon_v2": "Phikon v2",
    "plip": "PLIP",
    "provgigapath": "Prov-GigaPath",
    "uni": "UNI",
    "uni2h": "UNI2-h",
    "virchow": "Virchow",
    "virchow2": "Virchow2",
}
EXTRACTOR_PARAMETERS_DICT = {
    "conch": 86_000_000,
    "gpfm": 307_000_000,
    "hibou_vit_base": 86_000_000,
    "hibou_vit_large": 307_000_000,
    "h0_mini": 86_000_000,
    "hoptimus0": 1_100_000_000,
    "kaiko_vit_base_8": 86_000_000,
    "kaiko_vit_large_14": 307_000_000,
    "phikon": 86_000_000,
    "phikon_v2": 307_000_000,
    "plip": 86_000_000,
    "provgigapath": 1_100_000_000,
    "uni": 307_000_000,
    "uni2h": 681_000_000,
    "virchow": 632_000_000,
    "virchow2": 632_000_000,
}


[docs] def expand_columns(raw_results: pd.DataFrame) -> pd.DataFrame: """Expand columns so as to have one column per metric and robustness type.""" output = [] # Robustness types are "all", "inter-scanner", "inter-staining", # "inter-scanner, inter-staining" for robustness_type in raw_results["robustness_type"].unique(): subset = raw_results[ raw_results["robustness_type"] == robustness_type ].sort_values("extractor") subset = subset.set_index("extractor").iloc[:, 1:] subset.columns = [f"{c}__{robustness_type}" for c in subset.columns] output.append(subset) output_df = pd.concat(output, axis=1) output_df.insert( 0, "extractor", output_df.index.to_series().replace(EXTRACTOR_LABELS_DICT) ) output_df.insert( 1, "Parameters", output_df.index.to_series().replace(EXTRACTOR_PARAMETERS_DICT) ) return output_df
[docs] def display_plism_metrics( raw_results: pd.DataFrame, metric_x: str = "cosine_similarity_median", metric_y: str = "top_1_accuracy_median", robustness_x: str = "all", robustness_y: str = "all", label_x: str = "Median Cosine Similarity", label_y: str = "Median Top-1 Accuracy", fig_save_path: str | Path | None = None, xlim: tuple[float, float] | None = None, ylim: tuple[float, float] | None = None, palette: Any | None = None, ): """Display PLISM robustness metrics. Parameters ---------- raw_results: pd.DataFrame Raw results as computed by ``plismbench.utils.metrics.format_results``. metric_x: str = "cosine_similarity_median" Metric to display for x-axis. Should be of type 'metric_aggregation'. Supported metrics depends on the columns of ``raw_results`` but are by "cosine_similarity", "top_1_accuracy", "top_3_accuracy", "top_5_accuracy" and "top_10_accuracy". Supported aggregation types are either "mean" or "median". metric_y: str = "top_1_accuracy_median" Metric to display for y-axis. robustness_x: str = "all" Type of robustness for ``metric_x``. Supported types are "all", "inter-scanner", "inter-staining", "inter-scanner" and "inter-staining". robustness_y: str = "all" Type of robustness for ``metric_y``. label_x: str = "Median Cosine Similarity" Label for x-axis (can be anything). label_y: str = "Median Top-1 Accuracy" Label for y-axis (can be anything). xlim: tuple[float, float] | None = None Limits for x-axis. ylim: tuple[float, float] | None = None Limits for y-axis. palette = None Color palette. fig_save_path: str | Path | None = None Figure save path. """ # Set figure fig, ax = plt.subplots(1, 1, figsize=(10, 6), dpi=200) # Default limits for axes xlim = (0, 1) if xlim is None else xlim ylim = (0, 1) if xlim is None else ylim # Define x and y column in ``results_df`` col_x = f"{metric_x}__{robustness_x}" col_y = f"{metric_y}__{robustness_y}" results_df = expand_columns(raw_results) results_df = results_df[["extractor", "Parameters", col_x, col_y]] results_df[[col_x, col_y]] = results_df[[col_x, col_y]].astype(float) # Default color palette if palette is None: palette = sns.color_palette("tab20")[: results_df.shape[0]] # Display metrics for each extractor sns.scatterplot( data=results_df, x=col_x, y=col_y, hue="extractor", size="Parameters", sizes=(50, 2000), palette=palette, edgecolor="black", alpha=0.9, legend=True, ax=ax, ) # Display number of parameters for each extractor sns.scatterplot( data=results_df, x=col_x, y=col_y, s=5, color="black", marker="+", facecolor="black", alpha=0.7, legend=True, ax=ax, ) # Set labels and limits ax.set_xlabel(label_x) ax.set_ylabel(label_y) ax.set_xlim(xlim) ax.set_ylim(ylim) # Please leave these values as-is for _, row in results_df.iterrows(): plt.text( row[col_x] + 0.001, row[col_y] + 0.001, row["extractor"], fontsize=10, ha="left", va="bottom", fontweight="bold", ) # Please leave these values as-is: adding circles in the legend # proportionnal to the model size. scatter_handles = [ plt.scatter( [], [], s=50, edgecolor="black", color="gray", alpha=0.5, label=" 22M" ), plt.scatter( [], [], s=155, edgecolor="black", color="gray", alpha=0.5, label=" 86M" ), plt.scatter( [], [], s=555, edgecolor="black", color="gray", alpha=0.5, label=" 307M" ), plt.scatter( [], [], s=1143, edgecolor="black", color="gray", alpha=0.5, label=" 632M", ), plt.scatter( [], [], s=2000, edgecolor="black", color="gray", alpha=0.5, label=" 1,100M", ), ] ax.legend( handles=scatter_handles, title="No. parameters", loc="center left", bbox_to_anchor=(1, 0.6), fontsize=12, labelspacing=0.2, handleheight=2, ) # Export figure if fig_save_path is not None: fig.savefig(fig_save_path, dpi=300, bbox_inches="tight") plt.show()