mmlearn.modules.layers.normalization

Normalization layers.

Classes

L2Norm

L2 normalization.

class L2Norm(dim)

L2 normalization layer. Rescales the input so that it has unit L2 (Euclidean) norm along the given dimension.

Parameters:

dim (int) – The dimension (axis index) along which to normalize.

forward(x)

Apply L2 normalization to the input tensor.

Parameters:

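x (torch.Tensor) – Input tensor of shape (batch_sz, seq_len, dim). In this shape, dim denotes the size of the last axis, not the constructor argument of the same name.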

Returns:

Normalized tensor of shape (batch_sz, seq_len, dim).

Return type:

torch.Tensor
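
The behaviour described above can be illustrated with a minimal sketch. This is not the library's source: L2NormSketch is a hypothetical stand-in, and the use of torch.nn.functional.normalize is an assumption about how such a layer might be written. The tensor sizes are arbitrary, and the sketch assumes the layer is constructed with dim=-1 so that normalization runs over the last (feature) axis.

```python
import torch
import torch.nn.functional as F


class L2NormSketch(torch.nn.Module):
    """Hypothetical stand-in for L2Norm (assumption, not the library's code)."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim  # axis index along which to normalize

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Divide x by its L2 norm along self.dim. F.normalize clamps the
        # norm with a small eps, so all-zero vectors do not produce NaNs.
        return F.normalize(x, p=2.0, dim=self.dim)


torch.manual_seed(0)
x = torch.randn(4, 10, 32)        # (batch_sz, seq_len, feature size)
layer = L2NormSketch(dim=-1)      # normalize over the last axis
y = layer(x)

# Every feature vector in the output should now have (approximately) unit length.
norms = y.norm(p=2, dim=-1)       # shape (4, 10)
```

The output keeps the shape of the input; only the magnitude along the chosen axis changes. Passing a different dim (for example, 1) would instead normalize across the sequence axis.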