mmlearn.modules.layers.logit_scaling
Learnable logit scaling layer.
Classes
LearnableLogitScaling – Logit scaling layer.
- class LearnableLogitScaling(init_logit_scale=14.285714285714285, max_logit_scale=100, learnable=True)
Logit scaling layer.
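- Parameters:
init_logit_scale (float) – Initial value of the logit scale. Defaults to 1/0.07 ≈ 14.2857.
max_logit_scale (float) – Upper bound on the logit scale. Defaults to 100.
learnable (bool) – If True, the logit scale is trained with the rest of the model; otherwise it stays fixed. Defaults to True.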
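The constructor arguments suggest a CLIP-style temperature: an initial scale (the default 14.285714285714285 equals 1/0.07), an upper bound, and a switch for training it. Below is a minimal sketch of such a layer, assuming the common design in which the scale is stored as a logarithm (so it stays positive), capped at max_logit_scale, and multiplied into the input. The class name LogitScalingSketch is hypothetical; this is an illustrative re-implementation, not mmlearn's source.

```python
import math

import torch
from torch import nn


class LogitScalingSketch(nn.Module):
    """Hypothetical logit scaling layer (illustration only, not mmlearn's code)."""

    def __init__(
        self,
        init_logit_scale: float = 1 / 0.07,
        max_logit_scale: float = 100.0,
        learnable: bool = True,
    ) -> None:
        super().__init__()
        self.max_logit_scale = max_logit_scale
        # Assumption: the scale is kept in log space so exp() keeps it positive.
        log_scale = torch.tensor(math.log(init_logit_scale))
        if learnable:
            # Trainable scalar that receives gradients.
            self.log_logit_scale = nn.Parameter(log_scale)
        else:
            # Fixed scalar: saved in state_dict and moved by .to(), but never trained.
            self.register_buffer("log_logit_scale", log_scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Recover the scale, cap it at max_logit_scale, and multiply elementwise.
        scale = torch.clamp(self.log_logit_scale.exp(), max=self.max_logit_scale)
        return scale * x


layer = LogitScalingSketch()
x = torch.randn(2, 3, 4)  # (batch_sz, seq_len, dim)
out = layer(x)
print(out.shape)  # torch.Size([2, 3, 4])
```

Because the multiplication is elementwise by a scalar, the output keeps the input's shape. Under this sketch, an init_logit_scale above max_logit_scale is effectively clipped to max_logit_scale in the forward pass.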
- forward(x)
Apply the logit scaling to the input tensor.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch_sz, seq_len, dim).
- Return type: