mmlearn.modules.layers.logit_scaling

Learnable logit scaling layer.

Classes

LearnableLogitScaling

Logit scaling layer.

class LearnableLogitScaling(init_logit_scale=14.285714285714285, max_logit_scale=100, learnable=True)

Logit scaling layer.

Parameters:
  • init_logit_scale (float, optional, default=1/0.07) – Initial value of the logit scale (1/0.07 ≈ 14.2857, the value shown in the signature).

  • max_logit_scale (float, optional, default=100) – Maximum value of the logit scale.

  • learnable (bool, optional, default=True) – If True, the logit scale is learnable. Otherwise, it is fixed.
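The defaults above can be checked with a little arithmetic. A logit scale s is the reciprocal of a softmax temperature T (s = 1/T), so the default of 1/0.07 corresponds to T = 0.07. The sketch below assumes that max_logit_scale acts as a hard upper clip on the scale; effective_scale is a hypothetical helper, not part of mmlearn.

```python
# Documented defaults of LearnableLogitScaling.
init_logit_scale = 1 / 0.07
max_logit_scale = 100.0

# A logit scale is the reciprocal of a softmax temperature.
temperature = 1 / init_logit_scale


def effective_scale(scale: float, max_scale: float = max_logit_scale) -> float:
    """Hypothetical helper: cap a logit scale at max_scale (assumed hard clip)."""
    return min(scale, max_scale)


print(init_logit_scale)                   # 14.285714285714285, as in the signature
print(round(temperature, 2))              # 0.07
print(effective_scale(init_logit_scale))  # 14.285714285714285, below the cap
print(effective_scale(150.0))             # 100.0, clipped to max_logit_scale
```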

extra_repr()

Return the extra representation string that PyTorch includes when the layer is printed.

Return type:

str
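In PyTorch, nn.Module.__repr__ calls extra_repr() and places the returned string inside the parentheses after the class name. The module below is hypothetical and only illustrates that mechanism; the exact fields that LearnableLogitScaling reports are not documented here.

```python
from torch import nn


class ToyScale(nn.Module):
    """Hypothetical module used only to show how extra_repr feeds repr()."""

    def __init__(self, scale: float = 1 / 0.07, learnable: bool = True) -> None:
        super().__init__()
        self.scale = scale
        self.learnable = learnable

    def extra_repr(self) -> str:
        # Whatever is returned here appears between the parentheses in repr().
        return f"scale={self.scale:.4f}, learnable={self.learnable}"


print(ToyScale())  # ToyScale(scale=14.2857, learnable=True)
```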

forward(x)

Apply the logit scaling to the input tensor.

Parameters:

x (torch.Tensor) – Input tensor of shape (batch_sz, seq_len, dim).

Return type:

Tensor
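The following is a minimal sketch of a layer with the documented parameters, not mmlearn's actual code. It assumes that the scale is stored in log space, that learnable=False keeps it as a non-trainable buffer, and that forward multiplies the input by the scale clipped to max_logit_scale. LogitScalingSketch is a hypothetical name.

```python
import math

import torch
from torch import nn


class LogitScalingSketch(nn.Module):
    """Hypothetical re-creation of the documented behaviour (assumptions noted above)."""

    def __init__(
        self,
        init_logit_scale: float = 1 / 0.07,
        max_logit_scale: float = 100,
        learnable: bool = True,
    ) -> None:
        super().__init__()
        self.max_logit_scale = max_logit_scale
        # Assumption: the scale is parameterised in log space.
        log_scale = torch.tensor(math.log(init_logit_scale))
        if learnable:
            self.log_logit_scale = nn.Parameter(log_scale)  # updated by the optimizer
        else:
            self.register_buffer("log_logit_scale", log_scale)  # fixed, but saved in state_dict

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumption: clip the scale to max_logit_scale, then scale the input elementwise.
        scale = self.log_logit_scale.exp().clamp(max=self.max_logit_scale)
        return scale * x


x = torch.randn(2, 5, 8)  # (batch_sz, seq_len, dim)
layer = LogitScalingSketch()
y = layer(x)
print(y.shape)  # torch.Size([2, 5, 8])
print(sum(p.numel() for p in layer.parameters()))  # 1 trainable scalar
print(sum(p.numel() for p in LogitScalingSketch(learnable=False).parameters()))  # 0
```

Storing the logarithm keeps the scale positive under unconstrained gradient updates, which is a common choice for CLIP-style temperatures. The clip prevents the scale from growing without bound during training.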