mmlearn.datasets.processors.transforms

Custom transforms for datasets/inputs.

Functions

repeat_interleave_batch(x, b, repeat)

Repeat and interleave a tensor across the batch dimension.

Parameters:
  • x (torch.Tensor) – Input tensor to be repeated.

  • b (int) – Size of the batch to be repeated.

  • repeat (int) – Number of times to repeat each batch.

Returns:

The repeated tensor, with its batch dimension enlarged by the repetition.

Return type:

torch.Tensor
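
The ordering implied by "repeat and interleave" can be hard to picture from the parameter list alone. The following is a minimal sketch, not the library's implementation: it uses a plain Python list as a stand-in for the batch dimension so that it runs without PyTorch, and it assumes that the input is split into consecutive chunks of b items, each chunk being emitted repeat times before the next one starts. The function name repeat_interleave_chunks is hypothetical, and the handling of a trailing incomplete chunk (dropped here) is also an assumption.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def repeat_interleave_chunks(x: Sequence[T], b: int, repeat: int) -> List[T]:
    """Hypothetical stand-in: emit each chunk of b items `repeat` times in a row."""
    n = len(x) // b  # number of complete chunks (assumption: leftovers are dropped)
    out: List[T] = []
    for i in range(n):
        chunk = list(x[i * b : (i + 1) * b])  # the i-th chunk of size b
        for _ in range(repeat):
            out.extend(chunk)  # repeat the whole chunk before moving on
    return out


result = repeat_interleave_chunks([0, 1, 2, 3], b=2, repeat=2)
print(result)  # [0, 1, 0, 1, 2, 3, 2, 3]
```

Under the same assumption, a tensor version would slice x along its first dimension and join the repeated slices with torch.cat(..., dim=0).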

Classes

class TrimText(trim_size)

Trim text strings as a preprocessing step before tokenization.

Parameters:

trim_size (int) – The maximum length of the trimmed text.

__call__(sentence)

Trim the given sentence(s).

Parameters:

sentence (Union[str, list[str]]) – Sentence(s) to be trimmed.

Returns:

Trimmed sentence(s).

Return type:

Union[str, list[str]]

Raises:

TypeError – If the input sentence is not a string or list of strings.
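
The documented contract (a string or list of strings in, the same type out, TypeError otherwise) can be illustrated with a small stand-in class. This is a sketch under stated assumptions rather than the library's code: the class name TrimTextSketch is hypothetical, and it assumes that trim_size counts characters and that trimming keeps the leading characters of each string.

```python
from typing import List, Union


class TrimTextSketch:
    """Hypothetical stand-in for a trim-before-tokenization transform."""

    def __init__(self, trim_size: int) -> None:
        self.trim_size = trim_size  # assumption: measured in characters

    def __call__(self, sentence: Union[str, List[str]]) -> Union[str, List[str]]:
        if isinstance(sentence, str):
            # Keep at most trim_size leading characters.
            return sentence[: self.trim_size]
        if isinstance(sentence, list) and all(isinstance(s, str) for s in sentence):
            # Trim every string in the list; a new list is returned.
            return [s[: self.trim_size] for s in sentence]
        raise TypeError("sentence must be a string or a list of strings")


trim = TrimTextSketch(trim_size=5)
print(trim("hello world"))      # hello
print(trim(["abcdefg", "hi"]))  # ['abcde', 'hi']
```

Strings shorter than trim_size pass through unchanged, so the transform can be applied to a whole dataset without checking lengths first.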
