fl4health.model_bases.feature_extractor_buffer module
- class FeatureExtractorBuffer(model, flatten_feature_extraction_layers)
Bases: object
- __init__(model, flatten_feature_extraction_layers)
This class extracts features from the intermediate layers of a neural network model and stores them in a buffer. The features are captured by hooks registered on the model. They are stored in a dictionary whose keys are layer names and whose values are lists of the extracted features as torch Tensors.
Attributes:
model (nn.Module): The neural network model.
flatten_feature_extraction_layers (dict[str, bool]): A dictionary mapping layer names to flags that specify whether the features extracted from each layer are flattened.
fhooks (list[RemovableHandle]): The handles of the registered hooks, kept so that the hooks can be removed later.
accumulate_features (bool): A flag indicating whether features from successive forward passes are accumulated in the buffers or overwritten.
extracted_features_buffers (dict[str, list[torch.Tensor]]): A dictionary to store the extracted features for each layer.
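The hook-and-buffer mechanism described above can be sketched with plain PyTorch. This is a minimal illustration, not the fl4health implementation: the toy model, the record helper, and the buffers and handles names are hypothetical, and the sketch always appends (it has no accumulate_features switch). It relies on nn.Module.register_forward_hook, which calls the hook as hook(module, args, output) after each forward pass and returns a RemovableHandle.

```python
import functools

import torch
from torch import nn

# Hypothetical toy model; nn.Sequential names its children "0", "1", "2".
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Layer name -> list of captured output tensors (one entry per forward pass here).
buffers: dict[str, list[torch.Tensor]] = {"0": [], "2": []}
handles = []  # RemovableHandle objects, kept so the hooks can be removed later


def record(layer_name: str, module: nn.Module, args: tuple, output: torch.Tensor) -> None:
    # Detach so the buffer does not keep the autograd graph alive.
    buffers[layer_name].append(output.detach())


for name, module in model.named_modules():
    if name in buffers:
        # partial binds the layer name; PyTorch supplies (module, args, output).
        handles.append(module.register_forward_hook(functools.partial(record, name)))

model(torch.randn(3, 4))
model(torch.randn(3, 4))

for handle in handles:
    handle.remove()  # after removal, forward passes no longer touch the buffers
model(torch.randn(3, 4))

print({name: [tuple(t.shape) for t in ts] for name, ts in buffers.items()})
# {'0': [(3, 8), (3, 8)], '2': [(3, 2), (3, 2)]}
```

Keeping the handles is what makes clean removal possible; the third forward pass leaves both lists at two entries.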
- disable_accumulating_features()
Disables the accumulation of features in the buffers.
This method sets the accumulate_features attribute to False, so the buffers no longer accumulate features and are instead overwritten on each forward pass.
- Return type:
None
- enable_accumulating_features()
Enables the accumulation of features in the buffers over multiple forward passes.
This method sets the accumulate_features flag to True, so features from successive forward passes are accumulated in the buffers rather than overwritten. This is useful, for example, when extracting features from the model's intermediate layers over several batches during inference.
- Return type:
None
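The difference between the accumulating and overwriting modes can be illustrated with a small pure-Python stand-in. FeatureBufferSketch and its store method are hypothetical names used only for this example, strings stand in for feature tensors, and the default of False for accumulate_features is an assumption of the sketch. It assumes, as described above, that accumulation keeps every forward pass's features while the other mode keeps only the latest.

```python
from collections import defaultdict


class FeatureBufferSketch:
    """Hypothetical stand-in for the buffer's accumulate/overwrite behaviour."""

    def __init__(self) -> None:
        self.accumulate_features = False  # assumed default for this sketch
        self.extracted_features_buffers: dict[str, list] = defaultdict(list)

    def enable_accumulating_features(self) -> None:
        self.accumulate_features = True

    def disable_accumulating_features(self) -> None:
        self.accumulate_features = False

    def store(self, layer_name: str, features) -> None:
        # Called once per forward pass for a given layer.
        if self.accumulate_features:
            self.extracted_features_buffers[layer_name].append(features)
        else:
            # Overwrite: keep only the features of the latest forward pass.
            self.extracted_features_buffers[layer_name] = [features]


buf = FeatureBufferSketch()
buf.store("fc1", "batch-1")
buf.store("fc1", "batch-2")
print(buf.extracted_features_buffers["fc1"])  # ['batch-2']

buf.enable_accumulating_features()
buf.store("fc1", "batch-3")
print(buf.extracted_features_buffers["fc1"])  # ['batch-2', 'batch-3']
```

Storing a one-element list even in overwrite mode keeps the value type (a list per layer) the same in both modes.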
- find_last_common_prefix(prefix, layers_name)
Search the model's list of named modules for layers whose names start with the given prefix and return the last match.
- Parameters:
prefix (str) – The prefix of the layer name for registering the hook.
layers_name (list[str]) – The list of the model's named modules. The list is assumed to be ordered as in the model's forward pass, following a depth-first traversal. This allows the user to specify a generic layer name instead of the full hierarchical name.
- Returns:
The full name of the last named layer that matches the prefix.
- Return type:
str
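The lookup can be sketched as a linear scan that keeps the last name matching the prefix. This is an illustrative version, not the library's code: the function name find_last_with_prefix and the module names are hypothetical, the names are assumed to arrive in the depth-first order produced by nn.Module.named_modules, and raising ValueError when nothing matches is a choice made for this sketch rather than documented behaviour.

```python
def find_last_with_prefix(prefix: str, layers_name: list[str]) -> str:
    """Return the last entry of layers_name that starts with prefix (hypothetical helper)."""
    matches = [name for name in layers_name if name.startswith(prefix)]
    if not matches:
        # Assumption for this sketch: fail loudly rather than return a sentinel.
        raise ValueError(f"No layer name starts with {prefix!r}")
    return matches[-1]


# Depth-first names as named_modules() might yield them for a nested model (hypothetical).
names = ["", "encoder", "encoder.conv1", "encoder.relu", "encoder.conv2", "classifier"]
print(find_last_with_prefix("encoder", names))  # encoder.conv2
```

Passing the generic name encoder resolves to encoder.conv2, which, under the ordering assumption above, is the last layer inside the encoder block.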
- flatten(features)
Flattens all dimensions of the input tensor except the batch dimension. The features have shape (batch_size, *) and are flattened into a 2D tensor of shape (batch_size, feature_size), where feature_size is the product of the non-batch dimensions.
- Parameters:
features (torch.Tensor) – The input tensor of shape (batch_size, *).
- Returns:
The flattened tensor of shape (batch_size, feature_size).
- Return type:
torch.Tensor
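One way to express this reshape in PyTorch is shown below; it is a sketch, not necessarily how the library implements it. Using reshape rather than view also handles non-contiguous inputs such as transposed tensors, and torch.flatten(features, start_dim=1) gives the same result for inputs with at least two dimensions.

```python
import torch


def flatten(features: torch.Tensor) -> torch.Tensor:
    # Keep dimension 0 (the batch) and merge every remaining dimension into one.
    return features.reshape(features.shape[0], -1)


conv_features = torch.randn(8, 16, 4, 4)  # (batch_size, channels, height, width)
flat = flatten(conv_features)
print(flat.shape)  # torch.Size([8, 256])
```

Here feature_size is 16 * 4 * 4 = 256, and row i of the result holds all features of sample i.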
- forward_hook(layer_name)
Returns a hook function that is called during the forward pass of a module.
- Parameters:
layer_name (str) – The name of the layer.
- Returns:
The hook function, which takes the module, its input, and its output as arguments.
- Return type:
Callable
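Returning the hook from a factory lets each hook capture its own layer_name. A lambda defined directly inside a registration loop would instead look up the loop variable when called and see only its final value (Python's late binding). The sketch below assumes the hook is registered with PyTorch's nn.Module.register_forward_hook; make_forward_hook and the module-level buffers dictionary are hypothetical names, and the hook always appends rather than consulting an accumulate_features flag.

```python
import torch
from torch import nn

buffers: dict[str, list[torch.Tensor]] = {}


def make_forward_hook(layer_name: str):
    """Hypothetical factory: returns a hook bound to one layer name."""

    def hook(module: nn.Module, inputs: tuple, output: torch.Tensor) -> None:
        # PyTorch calls this after module.forward; output is the layer's output.
        buffers.setdefault(layer_name, []).append(output.detach())

    return hook


model = nn.Sequential(nn.Linear(4, 3), nn.Tanh())
for name, module in model.named_modules():
    if name:  # skip the root module, whose name is ""
        module.register_forward_hook(make_forward_hook(name))

x = torch.randn(2, 4)
model(x)
print(sorted(buffers))  # ['0', '1']
```

Because the hook returns None, PyTorch leaves the module's output unchanged; the hook only observes it.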