fl4health.model_bases.feature_extractor_buffer module

class FeatureExtractorBuffer(model, flatten_feature_extraction_layers)

Bases: object

__init__(model, flatten_feature_extraction_layers)

This class extracts features from the intermediate layers of a neural network model and stores them in a buffer. The features are captured by forward hooks registered on the model. They are stored in a dictionary whose keys are the layer names and whose values are lists of the extracted features as torch Tensors.

Parameters:
  • model (nn.Module) – The neural network model.

  • flatten_feature_extraction_layers (dict[str, bool]) – Dictionary of the layers to extract features from and whether to flatten them. Keys are the layer names as they appear in named_modules, and values are booleans.
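The mechanism described above can be sketched roughly as follows. This is a minimal, hypothetical stand-in, not the fl4health implementation: the class name MiniFeatureBuffer and the helpers _hook and register_hooks are illustrative, and the sketch assumes that layer names are exact keys of named_modules() and that each hook overwrites its buffer entry unless accumulation is enabled.

```python
import torch
import torch.nn as nn
from torch.utils.hooks import RemovableHandle


class MiniFeatureBuffer:
    """Hypothetical, simplified stand-in for FeatureExtractorBuffer (not the fl4health code)."""

    def __init__(self, model: nn.Module, flatten_feature_extraction_layers: dict[str, bool]) -> None:
        self.model = model
        self.flatten_feature_extraction_layers = flatten_feature_extraction_layers
        self.fhooks: list[RemovableHandle] = []
        self.accumulate_features = False
        # One list of tensors per requested layer.
        self.extracted_features_buffers: dict[str, list[torch.Tensor]] = {
            name: [] for name in flatten_feature_extraction_layers
        }

    def _hook(self, layer_name: str):
        # Closure so that each hook knows which buffer entry it writes to.
        def hook(module: nn.Module, inputs: tuple, output: torch.Tensor) -> None:
            if self.flatten_feature_extraction_layers[layer_name]:
                output = output.reshape(output.shape[0], -1)  # keep the batch dimension
            if self.accumulate_features:
                self.extracted_features_buffers[layer_name].append(output)
            else:
                self.extracted_features_buffers[layer_name] = [output]

        return hook

    def register_hooks(self) -> None:
        # Assumption: the requested names are exact keys of model.named_modules().
        modules = dict(self.model.named_modules())
        for name in self.flatten_feature_extraction_layers:
            self.fhooks.append(modules[name].register_forward_hook(self._hook(name)))


model = nn.Sequential(nn.Conv2d(1, 2, kernel_size=3), nn.ReLU())
buffer = MiniFeatureBuffer(model, {"0": True, "1": False})
buffer.register_hooks()
model(torch.randn(4, 1, 5, 5))
print({k: tuple(v[0].shape) for k, v in buffer.extracted_features_buffers.items()})
# {'0': (4, 18), '1': (4, 2, 3, 3)}
```

The convolution output (4, 2, 3, 3) is flattened to (4, 18) because its entry is True, while the ReLU output keeps its original shape.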

model

The neural network model.

Type:

nn.Module

flatten_feature_extraction_layers

A dictionary specifying whether to flatten the feature extraction layers.

Type:

dict[str, bool]

fhooks

A list to store the handles for removing hooks.

Type:

list[RemovableHandle]

accumulate_features

A flag indicating whether to accumulate features.

Type:

bool

extracted_features_buffers

A dictionary to store the extracted features for each layer.

Type:

dict[str, list[torch.Tensor]]

clear_buffers()

Clears the extracted features buffers for all layers.

Return type:

None

disable_accumulating_features()

Disables the accumulation of features in the buffers.

This method sets the accumulate_features attribute to False. The buffers then stop accumulating features and are instead overwritten on each forward pass.

Return type:

None

enable_accumulating_features()

Enables the accumulation of features in the buffers for multiple forward passes.

This method sets the accumulate_features flag to True, so that features from successive forward passes are appended to the buffers instead of overwriting them. This is useful, for example, when collecting intermediate-layer features over several batches during inference.

Return type:

None
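The interplay between accumulate_features, clear_buffers and the per-pass buffer update can be illustrated with a small, hypothetical class. It assumes that clearing resets every layer to an empty list and that a non-accumulating pass replaces the layer's list with a single entry; strings stand in for feature tensors, and the record method is an illustrative name for what a forward hook would do.

```python
class AccumulationDemo:
    """Hypothetical sketch of how accumulate_features could govern the buffers."""

    def __init__(self, layer_names: list[str]) -> None:
        self.accumulate_features = False
        self.extracted_features_buffers: dict[str, list] = {name: [] for name in layer_names}

    def enable_accumulating_features(self) -> None:
        self.accumulate_features = True

    def disable_accumulating_features(self) -> None:
        self.accumulate_features = False

    def clear_buffers(self) -> None:
        # Assumption: clearing means resetting every layer's list to empty.
        for name in self.extracted_features_buffers:
            self.extracted_features_buffers[name] = []

    def record(self, layer_name: str, features) -> None:
        # Hypothetical helper: what a forward hook would do once per forward pass.
        if self.accumulate_features:
            self.extracted_features_buffers[layer_name].append(features)
        else:
            self.extracted_features_buffers[layer_name] = [features]


demo = AccumulationDemo(["fc"])
demo.record("fc", "batch0")
demo.record("fc", "batch1")
print(demo.extracted_features_buffers)  # {'fc': ['batch1']}
demo.enable_accumulating_features()
demo.record("fc", "batch2")
print(demo.extracted_features_buffers)  # {'fc': ['batch1', 'batch2']}
demo.clear_buffers()
print(demo.extracted_features_buffers)  # {'fc': []}
```

Note that clear_buffers does not touch the flag, so accumulation stays enabled after clearing.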

find_last_common_prefix(prefix, layers_name)

Check the model’s list of named modules to filter any layer that starts with the given prefix and return the last one.

Parameters:
  • prefix (str) – The prefix of the layer name for registering the hook.

  • layers_name (list[str]) – The list of named modules of the model. The list is assumed to be sorted in the order of the model's forward pass, following a depth-first traversal. This allows the user to specify the generic name of a layer instead of its full hierarchical name.

Returns:

The complete name of the last named layer that matches the prefix.

Return type:

str
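A minimal sketch of the lookup described above, assuming a plain string-prefix match over the names in named_modules() order. Raising ValueError when nothing matches is an assumption of this sketch, not documented behavior.

```python
def find_last_common_prefix(prefix: str, layers_name: list[str]) -> str:
    # Keep every module name that starts with the prefix (plain string prefix match).
    matches = [name for name in layers_name if name.startswith(prefix)]
    if not matches:
        # Assumption: this sketch fails loudly when no layer matches.
        raise ValueError(f"No layer name starts with {prefix!r}")
    # With depth-first ordering, the last match is the last submodule visited under the prefix.
    return matches[-1]


# Names as nn.Module.named_modules() would yield them: root "", then depth-first pre-order.
names = ["", "encoder", "encoder.0", "encoder.1", "encoder.1.conv", "head"]
print(find_last_common_prefix("encoder", names))  # encoder.1.conv
print(find_last_common_prefix("head", names))  # head
```

Because the match is a plain string prefix, "encoder" would also match a sibling named "encoder2"; a stricter variant could compare against prefix + "." as well as the exact name.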

flatten(features)

Flattens every dimension of the input tensor except the batch dimension. The features have shape (batch_size, *) and are reshaped into a 2D tensor of shape (batch_size, feature_size).

Parameters:

features (torch.Tensor) – The input tensor of shape (batch_size, *).

Returns:

The flattened tensor of shape (batch_size, feature_size).

Return type:

torch.Tensor
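A minimal sketch of this flattening, assuming the input has at least two dimensions. For such inputs it matches torch.flatten(features, start_dim=1).

```python
import torch


def flatten(features: torch.Tensor) -> torch.Tensor:
    # Keep dimension 0 (batch) and merge all remaining dimensions into one.
    return features.reshape(features.shape[0], -1)


x = torch.randn(8, 3, 4, 4)
print(flatten(x).shape)  # torch.Size([8, 48])
```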

forward_hook(layer_name)

Returns a hook function that is called during the forward pass of a module.

Parameters:

layer_name (str) – The name of the layer.

Returns:

The hook function, which takes the module, its input, and its output tensors.

Return type:

Callable
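The hook factory can be sketched as a closure over the layer name. This standalone version takes the buffer dictionary as an explicit argument (an assumption made so the example is self-contained) and uses overwrite semantics only. The returned function has the hook(module, inputs, output) signature that nn.Module.register_forward_hook expects; returning None leaves the module's output unchanged.

```python
from typing import Callable

import torch
import torch.nn as nn


def forward_hook(layer_name: str, buffers: dict[str, list[torch.Tensor]]) -> Callable:
    # Returns a closure with the signature PyTorch expects for forward hooks.
    def hook(module: nn.Module, inputs: tuple[torch.Tensor, ...], output: torch.Tensor) -> None:
        buffers[layer_name] = [output]  # overwrite semantics (no accumulation)

    return hook


layer = nn.Linear(4, 2)
buffers: dict[str, list[torch.Tensor]] = {}
handle = layer.register_forward_hook(forward_hook("fc", buffers))
y = layer(torch.randn(3, 4))
print(buffers["fc"][0].shape)  # torch.Size([3, 2])
handle.remove()
```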

get_extracted_features()

Returns a dictionary of extracted features.

Returns:

A dictionary where the keys are the layer names and the values are the extracted features as torch Tensors.

Return type:

dict[str, torch.Tensor]
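Since the buffers hold a list of tensors per layer while this method returns one tensor per layer, a plausible reading is that the lists are concatenated along the batch dimension. The sketch below assumes exactly that; it is not confirmed by the documentation above.

```python
import torch


def get_extracted_features(
    extracted_features_buffers: dict[str, list[torch.Tensor]],
) -> dict[str, torch.Tensor]:
    # Assumption: the per-pass tensors of each layer are concatenated along dim 0 (batch).
    return {name: torch.cat(tensors, dim=0) for name, tensors in extracted_features_buffers.items()}


buffers = {"fc": [torch.ones(2, 3), torch.zeros(4, 3)]}
features = get_extracted_features(buffers)
print(features["fc"].shape)  # torch.Size([6, 3])
```

torch.cat raises an error on an empty list, so a layer whose buffer was just cleared would need special handling in this sketch.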

get_hierarchical_attr(module, layer_hierarchy)

Traverse the hierarchical attributes of the module to get the desired attribute. Hooks should be registered to specific layers of the model, not to nn.Sequential or nn.ModuleList.

Parameters:
  • module (nn.Module) – The nn.Module object to traverse.

  • layer_hierarchy (list[str]) – The hierarchical list of attribute names leading to the desired layer.

Returns:

The desired layer of the model.

Return type:

nn.Module
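One way to perform this traversal is to apply getattr once per level of the hierarchy. The sketch below assumes the hierarchy is obtained by splitting a dotted named_modules name; the Net model is hypothetical. nn.Module.__getattr__ resolves registered child modules, including nn.Sequential indices such as "0".

```python
from functools import reduce

import torch.nn as nn


def get_hierarchical_attr(module: nn.Module, layer_hierarchy: list[str]) -> nn.Module:
    # Follow the names one level at a time: module.encoder, then its child "0", and so on.
    return reduce(getattr, layer_hierarchy, module)


class Net(nn.Module):
    """Hypothetical model used only for this example."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
        self.head = nn.Linear(8, 2)


net = Net()
layer = get_hierarchical_attr(net, "encoder.0".split("."))
print(layer)  # Linear(in_features=4, out_features=8, bias=True)
```

Recent PyTorch versions also provide nn.Module.get_submodule("encoder.0"), which performs the same dotted-path lookup.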

remove_hooks()

Removes all hooks that the feature extractor buffer registered on the model and clears the hook list. It is typically called before checkpointing the model.

Return type:

None
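Removing hooks amounts to calling remove() on each stored RemovableHandle and emptying the list. This is a standalone sketch that takes the handle list as an argument. A plausible reason for doing this before checkpointing is that hooks stored on a module as lambdas or local closures cannot be pickled when the whole model is saved with torch.save.

```python
import torch
import torch.nn as nn
from torch.utils.hooks import RemovableHandle


def remove_hooks(fhooks: list[RemovableHandle]) -> None:
    # Detach every registered hook from its module, then forget the handles.
    for handle in fhooks:
        handle.remove()
    fhooks.clear()


model = nn.Linear(4, 2)
calls: list[str] = []
fhooks = [model.register_forward_hook(lambda module, inputs, output: calls.append("fired"))]

model(torch.randn(1, 4))  # the hook fires once
remove_hooks(fhooks)
model(torch.randn(1, 4))  # no hook is attached anymore
print(calls, fhooks)  # ['fired'] []
```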