mmlearn.modules.encoders.vision.TimmViT

class TimmViT(model_name, modality='RGB', projection_dim=768, pretrained=True, freeze_layers=False, freeze_layer_norm=True, peft_config=None, model_kwargs=None)[source]

Bases: Module

Vision Transformer model from timm.

Parameters:
  • model_name (str) – The name of the timm model to use.

  • modality (str, default="RGB") – The modality of the input data. This allows the model to be used with different image modalities, e.g., RGB, Depth, etc.

  • projection_dim (int, default=768) – The dimension of the projection head.

  • pretrained (bool, default=True) – Whether to use the pretrained weights.

  • freeze_layers (Union[int, float, list[int], bool], default=False) – Whether to freeze the model's layers. The int, float, and list[int] forms select which layers to freeze.

  • freeze_layer_norm (bool, default=True) – Whether to freeze the layer normalization layers.

  • peft_config (Optional[PeftConfig], optional, default=None) – The configuration from the peft library used to wrap the model for parameter-efficient fine-tuning.

  • model_kwargs (Optional[dict[str, Any]], default=None) – Additional keyword arguments for the model.
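
Example (a minimal construction sketch; the model name "vit_base_patch16_224" and the model_kwargs value shown are illustrative, not prescribed by this API):

    from mmlearn.modules.encoders.vision import TimmViT

    # Build a ViT-B/16 encoder with pretrained timm weights and a 768-dimensional
    # projection head. model_kwargs is assumed to be forwarded to timm's model factory.
    encoder = TimmViT(
        "vit_base_patch16_224",
        projection_dim=768,
        pretrained=True,
        freeze_layers=False,
        model_kwargs={"img_size": 224},  # illustrative extra keyword argument
    )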

Methods

forward(inputs)[source]

Run the forward pass.

Parameters:

inputs (dict[str, Any]) – The input data. The image will be expected under the Modalities.RGB key.

Returns:

The output of the model.

Return type:

BaseModelOutput
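
Example (a hedged sketch of a forward call; it reuses the encoder built in the construction example and assumes Modalities is importable from mmlearn.datasets.core and that the batch dictionary is keyed by Modalities.RGB, as stated above):

    import torch
    from mmlearn.datasets.core import Modalities  # assumed import path

    # Dummy batch of four 224x224 RGB images, keyed as described above.
    batch = {Modalities.RGB: torch.randn(4, 3, 224, 224)}
    output = encoder(batch)  # calls forward(); returns a BaseModelOutput
    features = output.last_hidden_state  # final features; exact shape depends on the timm model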

get_intermediate_layers(inputs, n=1)[source]

Get the output of the intermediate layers.

Parameters:
  • inputs (dict[str, Any]) – The input data. The image will be expected under the Modalities.RGB key.

  • n (int, default=1) – The number of intermediate layers to return.

Returns:

The outputs of the last n intermediate layers.

Return type:

list[torch.Tensor]
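
Example (a short sketch reusing the batch from the forward example; the choice n=2 is arbitrary):

    # Outputs of the last two transformer blocks, e.g. for dense prediction heads.
    intermediates = encoder.get_intermediate_layers(batch, n=2)
    assert len(intermediates) == 2  # one tensor per requested layer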

get_patch_info()[source]

Get patch size and number of patches.

Returns:

Patch size and number of patches.

Return type:

tuple[int, int]
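
Example (a sketch; the concrete numbers assume a ViT-B/16 encoder at 224x224 input resolution):

    patch_size, num_patches = encoder.get_patch_info()
    print(patch_size, num_patches)  # e.g. 16 196, since (224 / 16) ** 2 == 196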