plismbench.models.utils module

Utility functions to load and prepare feature extractors.

class plismbench.models.utils.MixedPrecisionModule(module: Module, device_type: str)

Bases: Module

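Mixed-precision wrapper: runs the wrapped module's forward pass under autocast.

Parameters:
  • module (torch.nn.Module) – Module to wrap.

  • device_type (str) – Device type used for autocast (e.g. "cuda" or "cpu").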

forward(*args, **kwargs)

Run the wrapped module's forward pass under autocast.
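A minimal sketch of what such a wrapper can look like, assuming it simply delegates to the inner module inside `torch.autocast`. The class name `AutocastWrapperSketch` is hypothetical and this is not the library's actual implementation.

```python
import torch
from torch import nn


class AutocastWrapperSketch(nn.Module):
    """Hypothetical stand-in for MixedPrecisionModule (assumption, not library code)."""

    def __init__(self, module: nn.Module, device_type: str) -> None:
        super().__init__()
        self.module = module  # registered as a submodule, so its parameters are tracked
        self.device_type = device_type  # e.g. "cuda" or "cpu", forwarded to torch.autocast

    def forward(self, *args, **kwargs):
        # Eligible ops (matmuls, linear layers, ...) run in a lower-precision dtype:
        # float16 by default on CUDA, bfloat16 by default on CPU.
        with torch.autocast(device_type=self.device_type):
            return self.module(*args, **kwargs)


# Usage on CPU: the linear layer's output is produced in CPU autocast's default dtype.
wrapped = AutocastWrapperSketch(nn.Linear(8, 4), device_type="cpu")
with torch.no_grad():
    out = wrapped(torch.randn(2, 8))
print(out.shape, out.dtype)
```

Keeping autocast inside `forward` means callers do not need to open an autocast context themselves.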

plismbench.models.utils.prepare_module(module: Module, device: int | list[int] | None = None, mixed_precision: bool = True) → tuple[Module, str | device]

Prepare a torch.nn.Module for inference by:
  • setting it to eval mode

  • disabling gradients

  • moving it to the correct device(s)
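The three steps above map onto standard `torch.nn.Module` calls. The following is a minimal sketch assuming a single target device; the helper name `prepare_for_inference_sketch` is hypothetical, and multi-GPU placement and mixed-precision wrapping are deliberately left out.

```python
import torch
from torch import nn


def prepare_for_inference_sketch(module: nn.Module, device: torch.device) -> nn.Module:
    """Hypothetical helper mirroring the documented steps (not library code)."""
    module.eval()                 # eval mode: dropout is disabled, batch norm uses running stats
    module.requires_grad_(False)  # no parameter will accumulate gradients
    return module.to(device)      # move parameters and buffers to the target device


model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.5), nn.ReLU())
model = prepare_for_inference_sketch(model, torch.device("cpu"))
y = model(torch.randn(2, 8))  # output does not require grad since no parameter does
```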

Parameters:
  • module (torch.nn.Module)

  • device (Union[None, int, list[int]] = None) – Compute resources to use. If None, all available GPUs are used. If -1, extraction runs on CPU.

  • mixed_precision (bool = True) – Whether to run the forward pass in mixed precision, which improves throughput on modern GPUs.
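One way to interpret the `device` argument is sketched below. It assumes that a non-negative int is a single CUDA index and a list holds several CUDA indices (the list case could then be served by `torch.nn.DataParallel` with `device_ids`). The CPU fallback when no GPU is visible is also an assumption. The function `resolve_device_sketch` is hypothetical and not part of plismbench.

```python
import torch


def resolve_device_sketch(device=None):
    """Hypothetical: map `device` to (primary device, GPU ids or None)."""
    if device == -1:
        # -1: run on CPU.
        return torch.device("cpu"), None
    if device is None:
        # None: use every visible GPU (assumed CPU fallback if there are none).
        n_gpus = torch.cuda.device_count()
        if n_gpus == 0:
            return torch.device("cpu"), None
        device = list(range(n_gpus))
    if isinstance(device, int):
        # A single GPU index.
        return torch.device(f"cuda:{device}"), [device]
    # A list of GPU indices: the first one holds the primary copy of the model.
    return torch.device(f"cuda:{device[0]}"), list(device)


primary, gpu_ids = resolve_device_sketch(-1)
```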

Return type:

tuple[torch.nn.Module, str | torch.device]