Source code for plismbench.models.utils

"""Utility functions to load and prepare feature extractors."""

from __future__ import annotations

import torch


DEFAULT_DEVICE = (
    0 if (torch.cuda.is_available() or torch.backends.mps.is_available()) else -1
)
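
# A minimal, hypothetical usage sketch (this call pattern is an assumption,
# not shown in this module): ``DEFAULT_DEVICE`` resolves to 0 when a CUDA or
# MPS accelerator is visible and to -1 (CPU) otherwise, so it can serve as a
# default for the ``device`` argument of ``prepare_module`` below.
#
#     model, device = prepare_module(my_module, device=DEFAULT_DEVICE)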


class MixedPrecisionModule(torch.nn.Module):
    """Mixed Precision Module wrapper.

    Parameters
    ----------
    module: torch.nn.Module
        Module to wrap so that its forward pass runs under autocast.
    device_type: str
        Device type passed to ``torch.amp.autocast`` ("cpu" or "cuda").
    """

    def __init__(self, module: torch.nn.Module, device_type: str):
        super().__init__()
        self.module = module
        self.device_type = device_type

    def forward(self, *args, **kwargs):
        """Forward pass using ``autocast``."""
        # Mixed precision forward
        with torch.amp.autocast(device_type=self.device_type):
            output = self.module(*args, **kwargs)
        if not isinstance(output, torch.Tensor):
            raise ValueError(
                "MixedPrecisionModule currently only supports models"
                " returning a single tensor."
            )
        # Cast back to float32 so downstream code sees a consistent dtype
        return output.to(torch.float32)
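
# A minimal usage sketch of the wrapper in isolation; the two-layer model
# below is a hypothetical stand-in, not part of plismbench.
backbone = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 8),
)
wrapped = MixedPrecisionModule(backbone, device_type="cpu")
with torch.no_grad():
    features = wrapped(torch.randn(4, 16))
# The forward ran under CPU autocast; the output is cast back to float32.
assert features.dtype == torch.float32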
def prepare_module(
    module: torch.nn.Module,
    device: int | list[int] | None = None,
    mixed_precision: bool = True,
) -> tuple[torch.nn.Module, str | torch.device]:
    """Prepare a ``torch.nn.Module`` for inference by:

    - setting it to eval mode
    - disabling gradients
    - moving it to the correct device(s)

    Parameters
    ----------
    module: torch.nn.Module
        Module to prepare.
    device: int | list[int] | None = None
        Compute resources to use. If None, all available GPUs are used.
        If -1, extraction runs on CPU.
    mixed_precision: bool = True
        Whether to use mixed precision (improves throughput on modern
        GPU cards).

    Returns
    -------
    tuple[torch.nn.Module, str | torch.device]
        The prepared module and the device it was moved to.
    """
    if mixed_precision:
        if not (torch.cuda.is_available() or device == -1):
            raise ValueError(
                "Mixed precision is only available for CUDA GPUs and CPU."
            )
        # Autocast on CPU when CUDA is unavailable or explicitly bypassed
        # with ``device=-1``; otherwise autocast on CUDA.
        module = MixedPrecisionModule(
            module,
            device_type="cpu"
            if (device == -1 or not torch.cuda.is_available())
            else "cuda",
        )
    device_: str | torch.device
    if device == -1 or not (
        torch.cuda.is_available() or torch.backends.mps.is_available()
    ):
        device_ = "cpu"
    elif torch.backends.mps.is_available():
        device_ = torch.device("mps")
    elif isinstance(device, int):
        device_ = f"cuda:{device}"
    else:
        # Use DataParallel to distribute the module on all available GPUs
        # (``device is None``) or on the given list of GPU indices.
        module = torch.nn.DataParallel(module, device_ids=device)
        device_ = "cuda:0" if device is None else f"cuda:{device[0]}"
    module.to(device_)
    module.eval()
    module.requires_grad_(False)
    return module, device_
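
# A usage sketch for ``prepare_module`` on CPU; the flatten-plus-linear
# extractor is again a hypothetical placeholder for any feature extractor
# returning a single tensor.
extractor = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(3 * 8 * 8, 64),
)
model, device = prepare_module(extractor, device=-1, mixed_precision=True)
with torch.inference_mode():
    feats = model(torch.randn(2, 3, 8, 8).to(device))
print(feats.shape, feats.dtype)  # torch.Size([2, 64]) torch.float32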