fl4health.model_bases.feature_extractor_buffer module
- class FeatureExtractorBuffer(model, flatten_feature_extraction_layers)
Bases: object
- __init__(model, flatten_feature_extraction_layers)
This class extracts features from the intermediate layers of a neural network model and stores them in a buffer. The features are captured by forward hooks registered on the model's layers. The extracted features are stored in a dictionary whose keys are the layer names and whose values are the extracted features as torch Tensors.
- Parameters:
model (nn.Module) – The neural network model.
flatten_feature_extraction_layers – A dictionary specifying, for each feature extraction layer, whether its extracted features should be flattened.
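The hook-and-buffer mechanism described above can be sketched in plain PyTorch. This is a minimal, hypothetical illustration, not the fl4health implementation: `MinimalFeatureBuffer`, `flatten_layers` and `remove_hooks` are names invented here, the hooks detach their outputs, and layers are looked up by their exact names from `named_modules()`.

```python
from typing import Callable

import torch
from torch import nn


class MinimalFeatureBuffer:
    """Hypothetical, simplified stand-in for a hook-based feature buffer (not the fl4health class)."""

    def __init__(self, model: nn.Module, flatten_layers: dict[str, bool]) -> None:
        self.flatten_layers = flatten_layers
        self.features: dict[str, torch.Tensor] = {}
        modules = dict(model.named_modules())
        # Register one forward hook per requested layer, keyed by its exact module name.
        self.handles = [
            modules[name].register_forward_hook(self._make_hook(name)) for name in flatten_layers
        ]

    def _make_hook(self, name: str) -> Callable[..., None]:
        def hook(module: nn.Module, inputs: tuple, output: torch.Tensor) -> None:
            out = output.detach()
            if self.flatten_layers[name]:
                # Keep the batch dimension and collapse all remaining dimensions.
                out = torch.flatten(out, start_dim=1)
            # Overwrite whatever was stored for this layer on the previous pass.
            self.features[name] = out

        return hook

    def remove_hooks(self) -> None:
        for handle in self.handles:
            handle.remove()


model = nn.Sequential(nn.Conv2d(1, 2, kernel_size=3), nn.ReLU(), nn.Flatten(), nn.Linear(2 * 6 * 6, 3))
buffer = MinimalFeatureBuffer(model, {"0": True, "3": False})
model(torch.randn(4, 1, 8, 8))
print({name: tuple(t.shape) for name, t in buffer.features.items()})  # {'0': (4, 72), '3': (4, 3)}
```

Keeping the handles returned by `register_forward_hook` makes it possible to detach the hooks once the features are no longer needed.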
- disable_accumulating_features()
Disables the accumulation of features in the buffers.
This method sets the accumulate_features attribute to False, so the buffers no longer accumulate features; instead, they are overwritten on each forward pass.
- Return type:
None
- enable_accumulating_features()
Enables the accumulation of features in the buffers for multiple forward passes.
This method sets the accumulate_features attribute to True, so features from successive forward passes are accumulated in the buffers rather than overwritten. This is useful, for example, when extracting intermediate-layer features over several batches during inference.
- Return type:
None
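The difference between the overwriting and accumulating modes can be illustrated with a small sketch. It assumes each buffer holds a list of per-batch tensors that are concatenated along the batch dimension when read; `AccumulationDemo`, `store` and `get` are hypothetical names, and fl4health may organize its buffers differently.

```python
import torch


class AccumulationDemo:
    """Hypothetical buffer showing overwrite vs. accumulate behaviour (not the fl4health class)."""

    def __init__(self) -> None:
        self.accumulate_features = False
        self.buffers: dict[str, list[torch.Tensor]] = {}

    def enable_accumulating_features(self) -> None:
        self.accumulate_features = True

    def disable_accumulating_features(self) -> None:
        self.accumulate_features = False

    def store(self, name: str, features: torch.Tensor) -> None:
        if self.accumulate_features:
            # Append this forward pass's features to those already stored.
            self.buffers.setdefault(name, []).append(features)
        else:
            # Replace any previously stored features.
            self.buffers[name] = [features]

    def get(self, name: str) -> torch.Tensor:
        # Concatenate the stored batches along the batch dimension.
        return torch.cat(self.buffers[name], dim=0)


overwrite = AccumulationDemo()
accumulate = AccumulationDemo()
accumulate.enable_accumulating_features()
for _ in range(3):  # three forward passes with a batch size of 4
    overwrite.store("fc", torch.zeros(4, 10))
    accumulate.store("fc", torch.zeros(4, 10))
print(overwrite.get("fc").shape)  # torch.Size([4, 10])
print(accumulate.get("fc").shape)  # torch.Size([12, 10])
```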
- find_last_common_prefix(prefix, layers_name)
Checks the model's list of named modules, keeps every layer whose name starts with the given prefix, and returns the last one.
- Parameters:
prefix (str) – The prefix of the layer name for registering the hook.
layers_name (list[str]) – The list of named modules of the model. The assumption is that the list of named modules is sorted in the order of the model's forward pass with depth-first traversal. This allows the user to specify the generic name of the layer instead of the full hierarchical name.
- Returns:
The complete name of the last named layer that matches the prefix.
- Return type:
str
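The prefix lookup can be sketched in pure Python. `last_layer_with_prefix` is a hypothetical stand-in, not the library method, and the `names` list mimics what `[name for name, _ in model.named_modules()]` would return, in depth-first order, for a model with an `encoder` and a `head` submodule. Note that a plain `startswith` test would also match a sibling such as `encoder2` for the prefix `encoder`.

```python
def last_layer_with_prefix(prefix: str, layers_name: list[str]) -> str:
    """Hypothetical helper: return the last layer name that starts with `prefix`."""
    # Keep only the names that start with the prefix, preserving their depth-first order.
    matches = [name for name in layers_name if name.startswith(prefix)]
    if not matches:
        raise ValueError(f"No layer name starts with {prefix!r}")
    # Under the ordering assumption, the last match is the last module run under that prefix.
    return matches[-1]


names = ["", "encoder", "encoder.conv", "encoder.relu", "head", "head.fc"]
print(last_layer_with_prefix("encoder", names))  # encoder.relu
print(last_layer_with_prefix("head", names))  # head.fc
```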
- flatten(features)
Flattens all dimensions of the input tensor except the batch dimension. The features have shape (batch_size, *) and are reshaped into a 2D tensor of shape (batch_size, feature_size), where feature_size is the product of the non-batch dimensions.
- Parameters:
features (torch.Tensor) – The input tensor of shape (batch_size, *).
- Returns:
The flattened tensor of shape (batch_size, feature_size).
- Return type:
torch.Tensor
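A minimal sketch of this reshaping, assuming the input has at least two dimensions; `flatten_features` is a hypothetical name. Here `torch.flatten(features, start_dim=1)` is equivalent to `features.reshape(features.shape[0], -1)`.

```python
import torch


def flatten_features(features: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: collapse all non-batch dimensions into one."""
    # start_dim=1 leaves dimension 0 (the batch) untouched.
    return torch.flatten(features, start_dim=1)


x = torch.randn(8, 16, 5, 5)
print(flatten_features(x).shape)  # torch.Size([8, 400])
```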
- forward_hook(layer_name)
Returns a hook function that is called during the forward pass of a module.
- Parameters:
layer_name (str) – The name of the layer.
- Returns:
The hook function, which takes the module, its input, and its output tensor as arguments.
- Return type:
Callable
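A hook factory of this kind can be sketched as a closure over the layer name. `make_forward_hook` and the `store` dictionary are hypothetical; the sketch relies only on PyTorch's `nn.Module.register_forward_hook`, whose hooks are called as `hook(module, args, output)` after each forward pass. Detaching the output is a choice made for this sketch.

```python
from typing import Any, Callable

import torch
from torch import nn

HookFn = Callable[[nn.Module, Any, torch.Tensor], None]


def make_forward_hook(layer_name: str, store: dict[str, torch.Tensor]) -> HookFn:
    """Hypothetical factory: build a hook that records the output of `layer_name` in `store`."""

    def hook(module: nn.Module, inputs: Any, output: torch.Tensor) -> None:
        # Detach so the stored features do not keep the autograd graph alive.
        store[layer_name] = output.detach()

    return hook


store: dict[str, torch.Tensor] = {}
layer = nn.Linear(3, 2)
handle = layer.register_forward_hook(make_forward_hook("fc", store))
layer(torch.ones(5, 3))
print(store["fc"].shape)  # torch.Size([5, 2])
handle.remove()  # stop recording once the features are no longer needed
```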