mmlearn.modules.layers.normalization.L2Norm

class L2Norm(dim)

Bases: Module

L2 normalization: rescales the input so that every slice taken along dim has unit L2 (Euclidean) norm.

Parameters:

dim (int) – The dimension along which to normalize.
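For reference, L2 normalization of one slice v taken along dim can be written as follows. The small constant epsilon that guards against division by zero is an assumption borrowed from torch.nn.functional.normalize (default 1e-12); this page does not say whether L2Norm uses one.

```latex
\hat{v} = \frac{v}{\max\left(\lVert v \rVert_2,\ \epsilon\right)},
\qquad
\lVert v \rVert_2 = \sqrt{\sum_{j} v_j^{2}}
% Worked example: v = (3, 4) gives \lVert v \rVert_2 = 5, so \hat{v} = (0.6, 0.8).
```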


forward(x)

Apply L2 normalization to the input tensor.

Parameters:

x (torch.Tensor) – Input tensor of shape (batch_sz, seq_len, dim).

Returns:

Normalized tensor of shape (batch_sz, seq_len, dim).

Return type:

torch.Tensor
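A minimal usage sketch follows. It does not import mmlearn. Instead it defines a hypothetical stand-in, L2NormSketch, on the assumption that L2Norm behaves like torch.nn.functional.normalize(x, p=2, dim=dim). Check the linked source before relying on details such as the epsilon value.

```python
import torch
import torch.nn.functional as F
from torch import nn


class L2NormSketch(nn.Module):
    """Hypothetical stand-in for L2Norm (assumption: wraps F.normalize with p=2)."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim  # dimension along which each slice is rescaled to unit norm

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # F.normalize divides each slice along `dim` by max(||slice||_2, eps), eps=1e-12.
        return F.normalize(x, p=2.0, dim=self.dim)


# A batch of 4 sequences, 10 tokens each, 32 features per token.
x = torch.randn(4, 10, 32)
norm = L2NormSketch(dim=-1)
y = norm(x)
print(y.shape)  # torch.Size([4, 10, 32]): the shape is unchanged

# Deterministic check: (3, 4) has L2 norm 5, so it maps to (0.6, 0.8).
z = norm(torch.tensor([[3.0, 4.0]]))
print(z)
```

Normalizing along the last dimension (dim=-1) is the usual choice for per-token embeddings. After this step, dot products between outputs are cosine similarities.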