mmlearn.modules.layers.logit_scaling.LearnableLogitScaling
- class LearnableLogitScaling(init_logit_scale=14.285714285714285, max_logit_scale=100, learnable=True)
Bases: torch.nn.Module
Logit scaling layer.
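This page does not show the layer's internals. The following is a minimal sketch that assumes, without confirmation from this page, the CLIP-style convention of keeping the scale in log space. Under that assumption, a layer with this signature would store log(init_logit_scale) as a trainable `torch.nn.Parameter` when `learnable=True` or as a fixed buffer otherwise, and cap the exponentiated scale at `max_logit_scale` in `forward`. The class name `LogitScalingSketch` is hypothetical, and the real mmlearn code may differ. The default `init_logit_scale` of 14.285714285714285 equals 1/0.07.

```python
import math

import torch
from torch import nn


class LogitScalingSketch(nn.Module):
    """Hypothetical re-implementation for illustration; not the mmlearn source."""

    def __init__(
        self,
        init_logit_scale: float = 1 / 0.07,
        max_logit_scale: float = 100,
        learnable: bool = True,
    ) -> None:
        super().__init__()
        self.max_logit_scale = max_logit_scale
        # Assumption: store log(scale) so exp() keeps the effective scale positive.
        log_scale = torch.tensor(math.log(init_logit_scale))
        if learnable:
            # Trainable: returned by parameters() and updated by the optimizer.
            self.log_logit_scale = nn.Parameter(log_scale)
        else:
            # Fixed: saved in state_dict() but not returned by parameters().
            self.register_buffer("log_logit_scale", log_scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Exponentiate, cap at max_logit_scale, then scale x elementwise.
        scale = self.log_logit_scale.exp().clamp(max=self.max_logit_scale)
        return scale * x


layer = LogitScalingSketch()
y = layer(torch.randn(2, 3, 4))  # input of shape (batch_sz, seq_len, dim)
```

Storing the logarithm keeps the effective scale positive under unconstrained gradient updates, and the cap ensures the scale applied in `forward` never exceeds `max_logit_scale`.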
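- Parameters:
init_logit_scale (float) – Initial value of the logit scale. Defaults to 14.285714285714285, i.e. 1/0.07.
max_logit_scale (float) – Maximum value of the logit scale. Defaults to 100.
learnable (bool) – Whether the logit scale is learnable; if False, it stays fixed. Defaults to True.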
- forward(x)
Apply the logit scaling to the input tensor.
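This page only says that `forward(x)` applies the logit scaling. Assuming that this means an elementwise multiplication by the exponentiated, capped scale, the call on a tensor of shape (batch_sz, seq_len, dim) reduces to the hypothetical function `apply_logit_scaling` below. With the default scale of 1/0.07, an input entry of 0.5 becomes about 7.14. The sketch also shows one consequence of this formulation. While the scale is below the cap, the log-scale receives gradient. Once exp(log_logit_scale) exceeds `max_logit_scale`, `torch.clamp` passes zero gradient back to it.

```python
import math

import torch


def apply_logit_scaling(
    x: torch.Tensor,
    log_logit_scale: torch.Tensor,
    max_logit_scale: float = 100.0,
) -> torch.Tensor:
    """Hypothetical functional form of forward(x); not part of the mmlearn API."""
    # Effective scale: exp of the log-scale, capped at max_logit_scale.
    scale = torch.clamp(log_logit_scale.exp(), max=max_logit_scale)
    # Broadcast the 0-dim scale over x; the output keeps x's shape.
    return scale * x


x = torch.full((2, 5, 8), 0.5)  # (batch_sz, seq_len, dim)
log_scale = torch.tensor(math.log(1 / 0.07), requires_grad=True)
y = apply_logit_scaling(x, log_scale)  # every entry is approximately 0.5 / 0.07

# Below the cap, d/ds sum(exp(s) * x) = exp(s) * sum(x), so the log-scale gets gradient.
y.sum().backward()

# Above the cap, torch.clamp blocks the gradient to the log-scale.
capped = torch.tensor(math.log(1000.0), requires_grad=True)
apply_logit_scaling(x, capped).sum().backward()
```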
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch_sz, seq_len, dim).
- Return type: