fl4health.utils.dataset module

class BaseDataset(transform, target_transform)

Bases: ABC, Dataset

update_target_transform(g)
Return type:

None

update_transform(f)
Return type:

None
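The reference above gives only the signatures of update_transform(f) and update_target_transform(g). The sketch below illustrates one plausible behavior. It assumes that an update chains the new function after any transform that is already set, and simply installs the new function when none is set. This chaining rule is an assumption, not documented behavior. SketchDataset and compose are hypothetical plain-Python stand-ins, not fl4health classes.

```python
from typing import Any, Callable, Optional


def compose(first: Callable[[Any], Any], second: Callable[[Any], Any]) -> Callable[[Any], Any]:
    """Return a function that applies `first`, then `second`."""
    return lambda x: second(first(x))


class SketchDataset:
    """Hypothetical stand-in for BaseDataset's transform bookkeeping (assumed behavior)."""

    def __init__(
        self,
        transform: Optional[Callable[[Any], Any]] = None,
        target_transform: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        self.transform = transform
        self.target_transform = target_transform

    def update_transform(self, f: Callable[[Any], Any]) -> None:
        # Assumption: chain f after the existing transform, or install it if none is set.
        self.transform = f if self.transform is None else compose(self.transform, f)

    def update_target_transform(self, g: Callable[[Any], Any]) -> None:
        # Same assumed chaining rule, applied to the target transform.
        self.target_transform = g if self.target_transform is None else compose(self.target_transform, g)


ds = SketchDataset(transform=lambda x: x + 1)
ds.update_transform(lambda x: x * 10)
print(ds.transform(2))  # (2 + 1) * 10, prints 30
ds.update_target_transform(str)
print(repr(ds.target_transform(5)))  # prints '5'
```

Both methods return None, matching the documented return type. Under this assumption, the order of updates matters, because later functions receive the output of earlier ones.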

class DictionaryDataset(data, targets)

Bases: Dataset

__init__(data, targets)

A torch dataset that accepts a dictionary of input data rather than a single torch.Tensor. This is useful when a model takes non-trivial inputs. For example, a language model may require both token ids and attention masks.

Parameters:
  • data (dict[str, list[torch.Tensor]]) – Model input data, given as a dictionary that maps each input name to a list of tensors.

  • targets (torch.Tensor) – Target tensor.
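To illustrate how dictionary-valued inputs pair with targets, the sketch below takes the entry at the same position from every field of data. It returns those entries together with the matching target. The sketch makes several assumptions:

* DictionaryDatasetSketch is hypothetical.
* Lists stand in for tensors.
* The (inputs, target) return shape and the length check are illustrative choices, not fl4health's documented behavior.

```python
from typing import Any, Dict, List, Sequence, Tuple


class DictionaryDatasetSketch:
    """Hypothetical, torch-free illustration of a dataset whose inputs are a dictionary."""

    def __init__(self, data: Dict[str, List[Any]], targets: Sequence[Any]) -> None:
        # Assumption: every input field holds one entry per datapoint, matching the targets.
        lengths = {key: len(values) for key, values in data.items()}
        if any(n != len(targets) for n in lengths.values()):
            raise ValueError(f"Field lengths {lengths} do not match {len(targets)} targets")
        self.data = data
        self.targets = targets

    def __len__(self) -> int:
        return len(self.targets)

    def __getitem__(self, index: int) -> Tuple[Dict[str, Any], Any]:
        # Take the index-th entry from every field, e.g. token ids and an attention mask.
        return {key: values[index] for key, values in self.data.items()}, self.targets[index]


ds = DictionaryDatasetSketch(
    data={
        "input_ids": [[101, 7592, 102], [101, 2088, 102]],
        "attention_mask": [[1, 1, 1], [1, 1, 0]],
    },
    targets=[0, 1],
)
inputs, target = ds[1]
print(inputs["attention_mask"], target)  # prints [1, 1, 0] 1
```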

class SslTensorDataset(data, targets=None, transform=None, target_transform=None)

Bases: TensorDataset

class SyntheticDataset(data, targets)

Bases: TensorDataset

__init__(data, targets)

A dataset for synthetically created data strictly in the form of PyTorch tensors. Generally, this dataset is only used for tests.

Parameters:
  • data (torch.Tensor) – Data tensor whose first dimension corresponds to the number of datapoints.

  • targets (torch.Tensor) – Target tensor whose first dimension corresponds to the number of datapoints.
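Both arguments must share their first dimension, which is the number of datapoints. Test fixtures therefore usually generate data and targets together. The sketch below builds such a pair with a fixed seed, so test runs are reproducible. make_synthetic_pair is a hypothetical helper. Nested Python lists stand in for the torch.Tensor arguments that SyntheticDataset actually expects.

```python
import random
from typing import List, Tuple


def make_synthetic_pair(
    num_datapoints: int, num_features: int, num_classes: int, seed: int = 42
) -> Tuple[List[List[float]], List[int]]:
    """Hypothetical helper: random features and integer labels sharing a first dimension."""
    rng = random.Random(seed)  # local seeded generator, so results are reproducible
    data = [[rng.gauss(0.0, 1.0) for _ in range(num_features)] for _ in range(num_datapoints)]
    targets = [rng.randrange(num_classes) for _ in range(num_datapoints)]
    return data, targets


data, targets = make_synthetic_pair(num_datapoints=8, num_features=3, num_classes=2)
# The first dimension of both objects is the number of datapoints.
print(len(data), len(targets), len(data[0]))  # prints 8 8 3
```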

class TensorDataset(data, targets=None, transform=None, target_transform=None)

Bases: BaseDataset
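TensorDataset accepts optional transform and target_transform callables, and targets defaults to None. The sketch below follows a common pattern for such datasets: it applies each transform when an item is accessed. This pattern is assumed here, not taken from the reference above. TensorDatasetSketch is hypothetical plain Python. Raising an error on item access when targets is None is also an assumption made for this illustration.

```python
from typing import Any, Callable, Optional, Sequence, Tuple


class TensorDatasetSketch:
    """Hypothetical, torch-free illustration of per-item transforms."""

    def __init__(
        self,
        data: Sequence[Any],
        targets: Optional[Sequence[Any]] = None,
        transform: Optional[Callable[[Any], Any]] = None,
        target_transform: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        self.data = data
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        # Assumption: item access requires targets; this sketch has no data-only mode.
        if self.targets is None:
            raise ValueError("targets must be set to index this dataset")
        x, y = self.data[index], self.targets[index]
        if self.transform is not None:
            x = self.transform(x)
        if self.target_transform is not None:
            y = self.target_transform(y)
        return x, y


ds = TensorDatasetSketch(data=[1.0, 2.0, 3.0], targets=[0, 1, 0], transform=lambda v: v / 2)
print(ds[1])  # prints (1.0, 1)
```

Applying transforms lazily in __getitem__ leaves the stored data untouched. Under this design, a transform changed later (for example, through update_transform) also affects items read afterwards.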

select_by_indices(dataset, selected_indices)

Extracts the subset of a dataset selected by the indices in the tensor selected_indices. The returned dataset should be of the same type as the input but contain only the data associated with the given indices.

Parameters:
  • dataset (D) – Dataset to be “subsampled” using the provided indices.

  • selected_indices (torch.Tensor) – Indices of the entries to select from the dataset's data and targets (if targets exist).

Raises:

TypeError – Raised if the provided dataset is not of a supported type.

Returns:

Dataset with only the data associated with the provided indices. Must be of a supported type.

Return type:

D
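The contract above can be sketched without torch: the output has the same type as the input, only the selected entries are kept, and a TypeError is raised for unsupported inputs. The sketch makes several assumptions:

* select_by_indices_sketch and ListDataset are hypothetical.
* A plain list of ints stands in for the selected_indices tensor.
* Supporting only ListDataset is a choice made to show the error path.

```python
from typing import Any, List, Optional, Sequence, TypeVar


class ListDataset:
    """Hypothetical dataset holding data and optional targets as lists."""

    def __init__(self, data: List[Any], targets: Optional[List[Any]] = None) -> None:
        self.data = data
        self.targets = targets


D = TypeVar("D", bound=ListDataset)


def select_by_indices_sketch(dataset: D, selected_indices: Sequence[int]) -> D:
    """Return a dataset of the same type that keeps only the selected positions."""
    if not isinstance(dataset, ListDataset):
        raise TypeError(f"Dataset type {type(dataset).__name__} is not supported")
    new_data = [dataset.data[i] for i in selected_indices]
    # Slice the targets too, but only if they exist.
    new_targets = None if dataset.targets is None else [dataset.targets[i] for i in selected_indices]
    # type(dataset)(...) keeps the input's concrete class (assumes subclasses share this constructor).
    return type(dataset)(new_data, new_targets)


full = ListDataset(data=["a", "b", "c", "d"], targets=[0, 1, 0, 1])
subset = select_by_indices_sketch(full, [0, 2])
print(subset.data, subset.targets)  # prints ['a', 'c'] [0, 0]
```

The sketch rebuilds the dataset with type(dataset) rather than hard-coding ListDataset. This mirrors the documented requirement that the returned dataset has the same type D as the input.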