Source code for fl4health.model_bases.feature_extractor_buffer

from collections.abc import Callable
from logging import INFO

import torch
import torch.nn as nn
from flwr.common.logger import log
from torch.utils.hooks import RemovableHandle


class FeatureExtractorBuffer:
    """
    Capture the outputs of selected intermediate layers of a model via forward hooks.

    Forward hooks are registered on the requested layers; each hook stores the layer's output in a per-layer
    buffer. The buffered outputs can then be retrieved as a dictionary keyed by the requested layer names, with
    the buffered tensors of each layer concatenated along dimension 0 and optionally flattened to 2D.

    Attributes:
        model (nn.Module): The neural network model.
        flatten_feature_extraction_layers (dict[str, bool]): Maps each requested layer name to whether its
            extracted features should be flattened.
        fhooks (list[RemovableHandle]): Handles of the registered hooks, kept so that they can be removed.
        accumulate_features (bool): Whether hook outputs are appended across forward passes (True) or the
            buffer is overwritten on every forward pass (False).
        extracted_features_buffers (dict[str, list[torch.Tensor]]): The captured outputs for each layer.
    """

    def __init__(self, model: nn.Module, flatten_feature_extraction_layers: dict[str, bool]) -> None:
        """
        Set up empty buffers for the requested layers. No hooks are registered at construction time.

        Args:
            model (nn.Module): The neural network model.
            flatten_feature_extraction_layers (dict[str, bool]): Dictionary of layers to extract features from
                and whether to flatten them. Keys are layer names (or name prefixes) as they appear in the
                model's ``named_modules`` and values are booleans.
        """
        self.model = model
        self.flatten_feature_extraction_layers = flatten_feature_extraction_layers
        self.fhooks: list[RemovableHandle] = []
        self.accumulate_features: bool = False
        self.extracted_features_buffers: dict[str, list[torch.Tensor]] = {
            layer: [] for layer in flatten_feature_extraction_layers
        }

    def enable_accumulating_features(self) -> None:
        """
        Enable accumulation of features across multiple forward passes.

        After this call, every forward pass appends the hooked outputs to the buffers instead of replacing
        the previously stored ones.
        """
        self.accumulate_features = True

    def disable_accumulating_features(self) -> None:
        """
        Disable accumulation of features.

        After this call, every forward pass overwrites the buffers so that only the outputs of the most
        recent forward pass are kept.
        """
        self.accumulate_features = False

    def clear_buffers(self) -> None:
        """Reset the extracted features buffer of every requested layer to an empty list."""
        self.extracted_features_buffers = {layer: [] for layer in self.flatten_feature_extraction_layers}

    def get_hierarchical_attr(self, module: nn.Module, layer_hierarchy: list[str]) -> nn.Module:
        """
        Walk down nested attributes of ``module`` to reach the desired layer.

        Hooks should be registered to specific layers of the model, not to ``nn.Sequential`` or
        ``nn.ModuleList``.

        Args:
            module (nn.Module): The module at which the traversal starts.
            layer_hierarchy (list[str]): The attribute names to follow, outermost first. Must not be empty.

        Returns:
            nn.Module: The layer reached after following every name in ``layer_hierarchy``.

        Raises:
            IndexError: If ``layer_hierarchy`` is empty.
            AttributeError: If one of the names does not exist on the module reached so far.
        """
        # Indexing the first element explicitly keeps the IndexError for an empty hierarchy.
        current = getattr(module, layer_hierarchy[0])
        for name in layer_hierarchy[1:]:
            current = getattr(current, name)
        return current

    def find_last_common_prefix(self, prefix: str, layers_name: list[str]) -> str:
        """
        Return the last entry of ``layers_name`` that starts with ``prefix``.

        This allows the user to specify the generic name of a layer instead of its full hierarchical name.

        Args:
            prefix (str): The prefix of the layer name for registering the hook.
            layers_name (list[str]): The list of named modules of the model. The assumption is that this list
                is sorted in the order of the model's forward pass with depth-first traversal.

        Returns:
            str: The complete name of the last named layer that matches the prefix.

        Raises:
            IndexError: If no entry of ``layers_name`` starts with ``prefix``.
        """
        matching_layers = [layer for layer in layers_name if layer.startswith(prefix)]
        if not matching_layers:
            # Same exception type as indexing an empty list, but the message names the offending prefix.
            raise IndexError(f"No layer name starts with the prefix '{prefix}'.")
        # Return the last element that matches the criteria
        return matching_layers[-1]

    def _maybe_register_hooks(self) -> None:
        """
        Register the forward hooks unless hooks are already registered.

        For every requested layer name, the last named module whose name starts with it is looked up and a
        forward hook storing that module's output is attached. The resulting handles are kept in ``fhooks``.
        """
        if self.fhooks:
            log(INFO, "Hooks already registered.")
            return

        log(INFO, "Starting to register hooks:")
        named_layers = [name for name, _ in self.model.named_modules()]
        for layer in self.flatten_feature_extraction_layers:
            log(INFO, f"Registering hook for layer: {layer}")
            # Find the last specific layer under a given generic name
            specific_layer = self.find_last_common_prefix(layer, named_layers)
            # Split the specific layer name by '.' to get the hierarchical attribute
            target_module = self.get_hierarchical_attr(self.model, specific_layer.split("."))
            # The buffer is keyed by the requested (generic) name, not the resolved one.
            self.fhooks.append(target_module.register_forward_hook(self.forward_hook(layer)))

    def remove_hooks(self) -> None:
        """
        Remove every registered hook from the model and empty the list of hook handles.

        It is typically called prior to checkpointing the model.
        """
        log(INFO, "Removing hooks.")
        for hook in self.fhooks:
            hook.remove()
        self.fhooks.clear()

    def forward_hook(self, layer_name: str) -> Callable:
        """
        Build the forward hook that stores a module's output in the buffer of ``layer_name``.

        Args:
            layer_name (str): The key of ``extracted_features_buffers`` that the hook writes to.

        Returns:
            Callable: The hook function that takes in a module, its inputs, and its output.
        """

        def hook(module: nn.Module, inputs: tuple[torch.Tensor, ...], output: torch.Tensor) -> None:
            if self.accumulate_features:
                self.extracted_features_buffers[layer_name].append(output)
            else:
                # Without accumulation only the output of the latest forward pass is kept.
                self.extracted_features_buffers[layer_name] = [output]

        return hook

    def flatten(self, features: torch.Tensor) -> torch.Tensor:
        """
        Flatten every dimension of ``features`` except the first one.

        Args:
            features (torch.Tensor): The input tensor of shape (batch_size, *).

        Returns:
            torch.Tensor: The flattened tensor of shape (batch_size, feature_size).
        """
        return features.reshape(len(features), -1)

    def get_extracted_features(self) -> dict[str, torch.Tensor]:
        """
        Return the buffered features of every layer as a single tensor per layer.

        The buffered tensors of a layer are concatenated along dimension 0 and flattened to 2D when the
        layer's entry in ``flatten_feature_extraction_layers`` is True.

        NOTE: ``torch.cat`` is called on each layer's buffer as-is, so a layer whose buffer is still empty
        (no forward pass captured yet) will make ``torch.cat`` fail.

        Returns:
            dict[str, torch.Tensor]: A dictionary where the keys are the layer names and the values are the
            extracted features as torch Tensors.
        """
        features = {}
        for layer, buffered_outputs in self.extracted_features_buffers.items():
            stacked = torch.cat(buffered_outputs, dim=0)
            features[layer] = self.flatten(stacked) if self.flatten_feature_extraction_layers[layer] else stacked
        return features