# Source code for plismbench.engine.cli
"""A module containing CLI commands of the repository."""
from __future__ import annotations
from pathlib import Path
from typing import Annotated, Union
import typer
from huggingface_hub import login, snapshot_download
from loguru import logger
from plismbench.engine.evaluate import compute_metrics
from plismbench.engine.extract.core import run_extract
from plismbench.models import FeatureExtractorsEnum
from plismbench.models.utils import DEFAULT_DEVICE
# Typer application named "plismbench"; the functions decorated with
# ``@app.command()`` in this module are registered on it as sub-commands.
app = typer.Typer(name="plismbench")
# [docs]
@app.command()
def extract(
    extractor: Annotated[
        str,
        typer.Option(
            "--extractor",
            help="The name of the feature extractor as defined in ``plismbench.models.__init__.py``",
        ),
    ],
    export_dir: Annotated[
        Path,
        typer.Option(
            "--export-dir",
            help=(
                "The root folder where features will be stored."
                " The final export directory is ``export_dir / extractor``"
            ),
        ),
    ],
    streaming: Annotated[
        bool,
        typer.Option(
            "--streaming",
            help="Whether to stream images instead of storing to disk (300Go).",
        ),
    ] = False,
    download_dir: Annotated[
        Union[Path, None],
        typer.Option(
            "--download-dir",
            help="Folder containing the .h5 files downloaded from Hugging Face.",
        ),
    ] = None,
    device: Annotated[
        int, typer.Option("--device", help="The CUDA devnumber or -1 for CPU.")
    ] = DEFAULT_DEVICE,
    batch_size: Annotated[
        int, typer.Option("--batch-size", help="Features extraction batch size.")
    ] = 32,
    workers: Annotated[
        int, typer.Option("--workers", help="Number of workers for async loading.")
    ] = 8,
    overwrite: Annotated[
        bool,
        typer.Option(
            "--overwrite",
            help="Whether to overwrite the previous features extraction run.",
        ),
    ] = False,
):
    """Perform features extraction on PLISM histology tiles dataset streamed from Hugging-Face.

    .. code-block:: console

        $ plismbench extract --extractor h0_mini --batch-size 32 --export-dir $HOME/tmp/features/ --download-dir $HOME/tmp/slides/
    """
    # NOTE: the docstring above doubles as the ``--help`` text of the command,
    # so implementation notes live in comments rather than in the docstring.

    # Reject unknown extractor names before any extraction work is started.
    known_extractors = FeatureExtractorsEnum.choices()
    is_supported = extractor in known_extractors
    if not is_supported:
        message = (
            f"Extractor {extractor} not supported."
            f" Supported extractors are: {known_extractors}."
        )
        raise NotImplementedError(message)

    # Each extractor gets its own sub-folder under the export root.
    extractor_export_dir = export_dir / extractor

    run_extract(
        feature_extractor_name=extractor,
        export_dir=extractor_export_dir,
        download_dir=download_dir,
        device=device,
        batch_size=batch_size,
        workers=workers,
        overwrite=overwrite,
        streaming=streaming,
    )
# [docs]
@app.command()
def download(
    download_dir: Annotated[
        Path,
        typer.Option(
            "--download-dir",
            help="Folder containing the .h5 files downloaded from Hugging Face.",
        ),
    ],
    hf_token: Annotated[str, typer.Option("--token", help="Hugging Face token.")],
    workers: Annotated[
        int, typer.Option("--workers", help="Number of workers for parallel download.")
    ] = 8,
):
    """Download PLISM dataset from Hugging Face."""
    # Authenticate with the provided token before requesting the dataset files.
    login(token=hf_token, new_session=False)

    # Only the ``*_to_GMH_S60.tif.h5`` files of the ``owkin/plism-dataset``
    # dataset repository are fetched into ``download_dir``.
    snapshot_download(
        repo_id="owkin/plism-dataset",
        repo_type="dataset",
        local_dir=download_dir,
        allow_patterns=["*_to_GMH_S60.tif.h5"],
        # Fixed typo: the pattern was ".gitattribues", which could never match
        # the repository's ``.gitattributes`` file.
        ignore_patterns=[".gitattributes"],
        max_workers=workers,
    )
# [docs]
@app.command()
def evaluate(
    extractor: Annotated[
        str,
        typer.Option(
            "--extractor",
            help="The name of the feature extractor as defined in ``plismbench.models.__init__.py``",
        ),
    ],
    features_dir: Annotated[
        Path,
        typer.Option(
            "--features-dir",
            help=(
                "The root folder containing the extracted features,"
                " i.e. the ``--export-dir`` passed to ``plismbench extract``."
            ),
        ),
    ],
    metrics_dir: Annotated[
        Path,
        typer.Option(
            "--metrics-dir",
            help=(
                "Folder containing the output metrics."
                " The final export directory is ``metrics_dir / extractor``."
            ),
        ),
    ],
    n_tiles: Annotated[
        Union[str, None],
        typer.Option(
            "--n-tiles", help="Number of tiles per slide for metrics computation."
        ),
    ] = None,
    top_k: Annotated[
        Union[str, None],
        typer.Option("--top-k", help="Values of k for top-k accuracy computation."),
    ] = None,
    device: Annotated[
        str,
        typer.Option(
            "--device", help="'cpu' (parallel computation) or 'gpu' (sequential)."
        ),
    ] = "gpu",
    workers: Annotated[
        int,
        typer.Option(
            "--workers", help="Number of workers for cpu parallel computations."
        ),
    ] = 4,
    overwrite: Annotated[
        bool,
        typer.Option(
            "--overwrite",
            help="Whether to overwrite existing metrics.",
        ),
    ] = False,
):
    """Compute robustness metrics for a feature extractor."""
    logger.info(f"Computing metrics for extractor {extractor}.")

    # ``--top-k`` arrives as a single string of whitespace-separated integers
    # (e.g. "1 3 5"). ``str.split()`` without a separator tolerates repeated,
    # leading and trailing whitespace, whereas ``split(" ")`` produced empty
    # items that crashed ``int('')`` with an unhelpful message.
    if top_k is None:
        top_k_values = None
    else:
        top_k_values = [int(k) for k in top_k.split()]
        if not top_k_values:
            # Same exception type as the former ``int('')`` failure, but with
            # a message that points at the offending option.
            raise ValueError("--top-k must contain at least one integer.")

    # ``--n-tiles`` is received as text; convert it only when it was provided.
    n_tiles_value = None if n_tiles is None else int(n_tiles)

    compute_metrics(
        features_root_dir=features_dir,
        metrics_save_dir=metrics_dir,
        extractor=extractor,
        top_k=top_k_values,
        n_tiles=n_tiles_value,
        device=device,
        overwrite=overwrite,
        workers=workers,
    )
# Run the Typer application when this module is executed as a script.
if __name__ == "__main__":
    app()