mmlearn.modules.encoders.vision.VisionTransformer

class VisionTransformer(modality='RGB', img_size=None, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, global_pool='', drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=torch.nn.LayerNorm, init_std=0.02)[source]

Bases: Module

Vision Transformer.

This module implements a Vision Transformer that splits images into patch embeddings and processes them with a stack of transformer blocks.

Parameters:
  • modality (str, optional, default="RGB") – The modality of the input data. This allows the model to be used with different image modalities, e.g. RGB or Depth.

  • img_size (List[int], optional, default=None) – List of input image sizes.

  • patch_size (int, optional, default=16) – Size of each patch.

  • in_chans (int, optional, default=3) – Number of input channels.

  • embed_dim (int, optional, default=768) – Embedding dimension.

  • depth (int, optional, default=12) – Number of transformer blocks.

  • num_heads (int, optional, default=12) – Number of attention heads.

  • mlp_ratio (float, optional, default=4.0) – Ratio of the MLP hidden dimension to the embedding dimension.

  • qkv_bias (bool, optional, default=True) – If True, add a learnable bias to the query, key, and value projections.

  • qk_scale (Optional[float], optional, default=None) – Override the default query-key scale factor. If None, the default scaling is used.

  • drop_rate (float, optional, default=0.0) – Dropout rate for the transformer blocks.

  • attn_drop_rate (float, optional, default=0.0) – Dropout rate for the attention mechanism.

  • drop_path_rate (float, optional, default=0.0) – Dropout rate for stochastic depth.

  • norm_layer (Callable[..., torch.nn.Module], optional, default=torch.nn.LayerNorm) – Normalization layer to use.

  • init_std (float, optional, default=0.02) – Standard deviation for weight initialization.

  • **kwargs (dict) – Additional keyword arguments.
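
A minimal construction sketch, assuming the defaults listed above. Only arguments documented in this section are used; the concrete values (e.g. img_size=[224]) are illustrative, not prescribed by the implementation.

    import torch.nn as nn
    from mmlearn.modules.encoders.vision import VisionTransformer

    # Illustrative configuration; unspecified arguments keep their defaults.
    encoder = VisionTransformer(
        modality="RGB",
        img_size=[224],
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        norm_layer=nn.LayerNorm,
    )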

Methods

fix_init_weight()[source]

Fix initialization of weights by rescaling them according to layer depth.

Return type:

None
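
The exact rescaling is not shown in this section. The sketch below illustrates the common depth-based scheme of dividing each block's output projections by sqrt(2 * layer_id); the attribute names (attn.proj, mlp.fc2) follow timm-style blocks and are assumptions, not taken from this implementation.

    import math
    import torch.nn as nn

    def rescale_block_weights(blocks: nn.ModuleList) -> None:
        # Common depth-based rescaling scheme (illustrative only).
        for layer_id, block in enumerate(blocks, start=1):
            scale = math.sqrt(2.0 * layer_id)
            # Rescale the attention and MLP output projections by layer depth.
            block.attn.proj.weight.data.div_(scale)
            block.mlp.fc2.weight.data.div_(scale)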

forward(inputs, return_hidden_states=False)[source]

Forward pass through the Vision Transformer.

Return type:

tuple[Tensor, Optional[list[Tensor]]]
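
A hedged call sketch, reusing the encoder constructed above. This section does not spell out the structure of inputs; a batch dictionary keyed by the lowercase modality name is assumed here purely for illustration. The unpacked values follow the documented return type.

    import torch

    # Hypothetical batch; the exact key/structure expected by `inputs`
    # is an assumption, not taken from this section.
    batch = {"rgb": torch.randn(2, 3, 224, 224)}

    output, hidden_states = encoder(batch, return_hidden_states=True)
    # output: torch.Tensor                  -- final encoder output
    # hidden_states: Optional[list[Tensor]] -- per-block outputs, or None
    #                                          when return_hidden_states=False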

interpolate_pos_encoding(x, pos_embed)[source]

Interpolate positional encoding to match the size of the input tensor.

Parameters:
  • x (torch.Tensor) – Input tensor.

  • pos_embed (torch.Tensor) – Positional encoding to interpolate.

Returns:

Interpolated positional encoding.

Return type:

torch.Tensor
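
This section does not show how the interpolation is done. The sketch below illustrates the standard approach of bicubic-resizing the patch-grid positional embedding; the function name and the assumption of a [1, N, D] embedding over a square grid with no extra (e.g. class) tokens are illustrative, and the actual method may differ.

    import math
    import torch
    import torch.nn.functional as F

    def interpolate_patch_pos_embed(pos_embed: torch.Tensor, num_patches: int) -> torch.Tensor:
        # pos_embed: [1, N, D] positional embedding over a square patch grid.
        n, dim = pos_embed.shape[1], pos_embed.shape[2]
        if num_patches == n:
            return pos_embed
        old_size = int(math.sqrt(n))
        new_size = int(math.sqrt(num_patches))
        # Reshape to a 2D grid, resize with bicubic interpolation, flatten back.
        grid = pos_embed.reshape(1, old_size, old_size, dim).permute(0, 3, 1, 2)
        grid = F.interpolate(grid, size=(new_size, new_size), mode="bicubic", align_corners=False)
        return grid.permute(0, 2, 3, 1).reshape(1, new_size * new_size, dim)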