mmlearn.modules.encoders.vision.VisionTransformerPredictor

class VisionTransformerPredictor(num_patches=196, embed_dim=768, predictor_embed_dim=384, depth=6, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=torch.nn.LayerNorm, init_std=0.02, **kwargs)[source]

Bases: Module

Vision Transformer Predictor.

This module implements a Vision Transformer that predicts masked tokens using a series of transformer blocks.

Parameters:
  • num_patches (int) – The number of patches in the input image.

  • embed_dim (int, optional, default=768) – The embedding dimension.

  • predictor_embed_dim (int, optional, default=384) – The embedding dimension for the predictor.

  • depth (int, optional, default=6) – The number of transformer blocks.

  • num_heads (int, optional, default=12) – The number of attention heads.

  • mlp_ratio (float, optional, default=4.0) – Ratio of the MLP hidden dimension to the embedding dimension.

  • qkv_bias (bool, optional, default=True) – If True, add a learnable bias to the query, key, and value projections.

  • qk_scale (Optional[float], optional, default=None) – Override the default qk scale factor.

  • drop_rate (float, optional, default=0.0) – Dropout rate for the transformer blocks.

  • attn_drop_rate (float, optional, default=0.0) – Dropout rate for the attention mechanism.

  • drop_path_rate (float, optional, default=0.0) – Dropout rate for stochastic depth.

  • norm_layer (Callable[..., torch.nn.Module], optional, default=torch.nn.LayerNorm) – Normalization layer to use.

  • init_std (float, optional, default=0.02) – Standard deviation for weight initialization.

  • **kwargs (dict) – Additional keyword arguments.
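
A minimal instantiation sketch is shown below, using the defaults from the signature above; the import path follows the module name at the top of this page, and the concrete values (14×14 patch grid, ViT-Base-sized context encoder) are illustrative assumptions.

```python
import torch
from mmlearn.modules.encoders.vision import VisionTransformerPredictor

# Predictor over a 14x14 grid of patches (196 patches), sized to match a
# ViT-Base context encoder whose output dimension is 768.
predictor = VisionTransformerPredictor(
    num_patches=196,
    embed_dim=768,            # must match the context encoder's output dimension
    predictor_embed_dim=384,  # narrower internal width used inside the predictor
    depth=6,
    num_heads=12,
)
```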

Methods


fix_init_weight()[source]

Fix initialization of weights by rescaling them according to layer depth.

Return type:

None
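
The docstring only states that weights are rescaled according to layer depth. A common scheme for this (used in BEiT- and I-JEPA-style initialization) divides each block's attention output projection and final MLP projection by sqrt(2 * layer_id); the sketch below illustrates that idea and is an assumption, not a transcription of this method's implementation.

```python
import math
import torch.nn as nn

def rescale_by_depth(blocks: nn.ModuleList) -> None:
    """Illustrative depth-dependent rescaling (assumed scheme, not the library source).

    Dividing each residual branch's output projection by sqrt(2 * layer_id)
    keeps the variance added by the residual branches roughly constant as
    depth grows.
    """
    for layer_id, block in enumerate(blocks, start=1):
        # Assumes each block exposes `attn.proj` and `mlp.fc2`, as in timm-style ViT blocks.
        block.attn.proj.weight.data.div_(math.sqrt(2.0 * layer_id))
        block.mlp.fc2.weight.data.div_(math.sqrt(2.0 * layer_id))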

forward(x, masks_x, masks)[source]

Forward pass through the Vision Transformer Predictor.

Return type:

Tensor
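
The signature above does not document its arguments. In I-JEPA-style predictors, x holds the context-token embeddings, masks_x the patch indices those context tokens correspond to, and masks the indices of the patches whose representations should be predicted. The call below is a sketch under that assumption; the exact mask format (here, lists of per-sample index tensors) should be confirmed against the source.

```python
import torch
from mmlearn.modules.encoders.vision import VisionTransformerPredictor

predictor = VisionTransformerPredictor(num_patches=196, embed_dim=768)

batch_size, embed_dim = 4, 768
x = torch.randn(batch_size, 49, embed_dim)  # embeddings of 49 visible context patches

# Assumed mask format: lists of per-sample patch-index tensors.
masks_x = [torch.arange(49).unsqueeze(0).repeat(batch_size, 1)]    # indices of the context patches
masks = [torch.arange(49, 98).unsqueeze(0).repeat(batch_size, 1)]  # indices of the patches to predict

preds = predictor(x, masks_x, masks)  # Tensor of predicted representations for the masked patches
```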