fl4health.utils.dataset module

class BaseDataset(transform, target_transform)

Bases: ABC, Dataset

__init__(transform, target_transform)

Abstract base class for datasets used in this library. This class inherits from the torch Dataset base class.

Parameters:
  • transform (Callable | None, optional) – Optional transformation to be applied to the input data. NOTE: This transformation is applied at load time within __getitem__. Defaults to None.

  • target_transform (Callable | None, optional) – Optional transformation to be applied to the target data. NOTE: This transformation is applied at load time within __getitem__. Defaults to None.
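The load-time behaviour described above can be illustrated with a minimal, framework-free sketch. ToyDataset is a hypothetical stand-in for a BaseDataset subclass, and plain Python lists replace torch.Tensor so the sketch runs without PyTorch. The point it shows is that transform and target_transform run inside __getitem__, each time an item is fetched, while the stored data stays untouched.

```python
from typing import Callable, Optional


class ToyDataset:
    """Hypothetical stand-in for a BaseDataset subclass (lists instead of tensors)."""

    def __init__(
        self,
        data: list,
        targets: list,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        self.data = data
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> tuple:
        x, y = self.data[index], self.targets[index]
        # Transforms are applied lazily, at load time, on every fetch.
        if self.transform is not None:
            x = self.transform(x)
        if self.target_transform is not None:
            y = self.target_transform(y)
        return x, y


ds = ToyDataset([1.0, 2.0, 3.0], [0, 1, 0], transform=lambda x: x * 10, target_transform=lambda y: y + 100)
print(ds[1])    # (20.0, 101)
print(ds.data)  # [1.0, 2.0, 3.0]  (stored data is not modified)
```

Applying transforms lazily means random augmentations produce a fresh result on every access instead of being frozen into the stored data.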

update_target_transform(g)
Return type:

None

update_transform(f)
Return type:

None
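This page gives no description for update_transform and update_target_transform. One plausible reading, which is an assumption and is not confirmed by this reference, is that each method composes the new function (f or g) after any existing transform, or installs it when no transform is set. The sketch below implements that assumed behaviour on a hypothetical TransformHolder class.

```python
from typing import Callable, Optional


class TransformHolder:
    """Hypothetical class illustrating ASSUMED update_transform semantics."""

    def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
        self.transform = transform
        self.target_transform = target_transform

    def update_transform(self, f: Callable) -> None:
        # Assumption: run f after the existing transform, or install f if there is none.
        if self.transform is None:
            self.transform = f
        else:
            previous = self.transform
            self.transform = lambda x: f(previous(x))

    def update_target_transform(self, g: Callable) -> None:
        # Same assumed composition rule, applied to the target transform.
        if self.target_transform is None:
            self.target_transform = g
        else:
            previous = self.target_transform
            self.target_transform = lambda y: g(previous(y))


holder = TransformHolder(transform=lambda x: x + 1)
holder.update_transform(lambda x: x * 2)
print(holder.transform(3))  # (3 + 1) * 2 = 8
holder.update_target_transform(str)
print(holder.target_transform(5))  # '5'
```

If the real methods replace the existing transform rather than composing with it, only the else branches would change.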

class DictionaryDataset(data, targets)

Bases: Dataset

__init__(data, targets)

A torch dataset that supports a dictionary of input data rather than just a single torch.Tensor. This kind of dataset is useful when a model takes several named inputs. For example, a language model may require token ids AND attention masks. This dataset supports that functionality.

Parameters:
  • data (dict[str, list[torch.Tensor]]) – Input data for the model, given as a dictionary that maps each input name (for example, token ids or attention masks) to a list of tensors.

  • targets (torch.Tensor) – Target tensor.
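A sketch of how a dictionary-based dataset might index its inputs, assuming each key maps to a list with one entry per datapoint and that __getitem__ returns a (dict, target) pair. The class name ToyDictionaryDataset and the keys input_ids and attention_mask are hypothetical, and lists stand in for tensors so the sketch runs without PyTorch.

```python
class ToyDictionaryDataset:
    """Hypothetical dictionary-of-inputs dataset (lists instead of tensors)."""

    def __init__(self, data: dict[str, list], targets: list) -> None:
        self.data = data
        self.targets = targets

    def __len__(self) -> int:
        # Assumes every key holds the same number of datapoints.
        first_key = next(iter(self.data))
        return len(self.data[first_key])

    def __getitem__(self, index: int) -> tuple[dict, object]:
        # Pick the index-th entry under every key, plus the matching target.
        return {key: values[index] for key, values in self.data.items()}, self.targets[index]


ds = ToyDictionaryDataset(
    data={"input_ids": [[101, 7, 102], [101, 9, 102]], "attention_mask": [[1, 1, 1], [1, 1, 0]]},
    targets=[0, 1],
)
print(len(ds))  # 2
print(ds[1])    # ({'input_ids': [101, 9, 102], 'attention_mask': [1, 1, 0]}, 1)
```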

class SslTensorDataset(data, targets=None, transform=None, target_transform=None)

Bases: TensorDataset

__init__(data, targets=None, transform=None, target_transform=None)

Dataset designed for self-supervised learning, where there is no separate set of targets because the targets are derived from the data tensors themselves.

Parameters:
  • data (torch.Tensor) – Tensor representing the input data for the dataset.

  • targets (torch.Tensor | None, optional) – REQUIRED TO BE NONE. This argument and its type exist only to maintain compatibility with the TensorDataset base class. Defaults to None.

  • transform (Callable | None, optional) – Any transform to be applied to the data tensors. This transform is performed BEFORE any target transforms that produce the self-supervised targets from the data. NOTE: This transform and the target_transform function are applied AT LOAD TIME. Defaults to None.

  • target_transform (Callable | None, optional) – Any transform to be applied to the data tensors to produce target tensors for training. This transform is performed AFTER any transforms of the data tensors themselves, producing the self-supervised targets from the data. NOTE: These transformation functions are applied AT LOAD TIME. Defaults to None.
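The ordering described above (data transform first, then a target transform that derives the target from the data) can be sketched as follows. ToySslDataset is hypothetical, lists stand in for tensors, and the sketch assumes that a target_transform must be supplied, since otherwise there is no way to produce a target. The "predict the reversed sequence" pretext task is likewise only a toy choice.

```python
from typing import Callable, Optional


class ToySslDataset:
    """Hypothetical self-supervised dataset (lists instead of tensors)."""

    def __init__(
        self,
        data: list,
        targets: Optional[list] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        # Targets must be None: they are derived from the data at load time.
        assert targets is None, "targets must be None for self-supervised data"
        self.data = data
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> tuple[list, list]:
        assert self.target_transform is not None, "a target_transform is needed to derive targets"
        x = self.data[index]
        if self.transform is not None:
            x = self.transform(x)  # 1) transform the input first
        y = self.target_transform(x)  # 2) then derive the target from the transformed input
        return x, y


ds = ToySslDataset(
    data=[[1, 2, 3], [4, 5, 6]],
    transform=lambda x: [v * 2 for v in x],
    target_transform=lambda x: x[::-1],  # toy pretext task: predict the reversed sequence
)
print(ds[0])  # ([2, 4, 6], [6, 4, 2])
```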

class SyntheticDataset(data, targets)

Bases: TensorDataset

__init__(data, targets)

A dataset for synthetically created data stored strictly as PyTorch tensors. It is generally used only for tests.

Parameters:
  • data (torch.Tensor) – Data tensor whose first dimension corresponds to the number of datapoints.

  • targets (torch.Tensor) – Target tensor whose first dimension corresponds to the number of datapoints.
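The shape requirement above (the first dimension of both data and targets is the number of datapoints) can be illustrated with a small generator. make_synthetic_data is a hypothetical helper, and nested lists stand in for tensors so the sketch runs without PyTorch; a fixed seed keeps test data reproducible.

```python
import random


def make_synthetic_data(
    num_points: int, num_features: int, num_classes: int, seed: int = 0
) -> tuple[list[list[float]], list[int]]:
    """Hypothetical helper: random features and labels, one row per datapoint."""
    rng = random.Random(seed)  # seeded so repeated calls give identical data
    data = [[rng.gauss(0.0, 1.0) for _ in range(num_features)] for _ in range(num_points)]
    targets = [rng.randrange(num_classes) for _ in range(num_points)]
    return data, targets


data, targets = make_synthetic_data(num_points=8, num_features=4, num_classes=3)
# The outer length (the "first dimension") is the number of datapoints for both.
assert len(data) == len(targets) == 8
assert all(len(row) == 4 for row in data)
```

With PyTorch installed, equivalent inputs would be torch.randn(8, 4) for data and torch.randint(0, 3, (8,)) for targets.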

class TensorDataset(data, targets=None, transform=None, target_transform=None)

Bases: BaseDataset

__init__(data, targets=None, transform=None, target_transform=None)

Basic dataset where the data and targets are assumed to be torch tensors. Optionally, this class allows the user to perform transformations on both the data and the targets. This is useful, for example, in performing data augmentation, label blurring, etc.

Parameters:
  • data (torch.Tensor) – Input data for training.

  • targets (torch.Tensor | None, optional) – Target data for training. Defaults to None.

  • transform (Callable | None, optional) – Optional transformation to be applied to the input data. NOTE: This transformation is applied at load time within __getitem__. Defaults to None.

  • target_transform (Callable | None, optional) – Optional transformation to be applied to the target data. NOTE: This transformation is applied at load time within __getitem__. Defaults to None.
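As an illustration of the augmentation and "label blurring" use cases mentioned above, here are two hypothetical transform functions that could be passed as transform and target_transform. Interpreting label blurring as label smoothing is an assumption, and this uses one common smoothing variant that spreads epsilon evenly over the non-target classes. Lists stand in for tensors.

```python
def flip(x: list[float]) -> list[float]:
    """Toy augmentation: reverse the feature order (a 1-D analogue of a horizontal flip)."""
    return x[::-1]


def smooth_label(label: int, num_classes: int = 3, epsilon: float = 0.1) -> list[float]:
    """Label smoothing: 1 - epsilon on the true class, epsilon shared equally by the others."""
    off_value = epsilon / (num_classes - 1)
    return [1.0 - epsilon if c == label else off_value for c in range(num_classes)]


# What a load-time __getitem__ would do with transform=flip, target_transform=smooth_label:
raw_x, raw_y = [1.0, 2.0, 3.0], 2
x, y = flip(raw_x), smooth_label(raw_y)
print(x)  # [3.0, 2.0, 1.0]
print(y)  # approximately [0.05, 0.05, 0.9]
```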

select_by_indices(dataset, selected_indices)

Extracts the subset of a dataset given by the indices in the tensor selected_indices. The returned dataset is of the same type as the input but contains only the data associated with the given indices.

Parameters:
  • dataset (D) – Dataset to be “subsampled” using the provided indices.

  • selected_indices (torch.Tensor) – Indices within the dataset's data and targets (if they exist) to select.

Raises:

TypeError – Raised if the type of the provided dataset is not supported.

Returns:

Dataset with only the data associated with the provided indices. Must be of a supported type.

Return type:

D
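A minimal sketch of the index-selection logic, assuming the supported types are a tensor-style dataset with optional targets and a dictionary-style dataset. This page implies but does not list those types. ToyTensorDataset, ToyDictionaryDataset and select_by_indices_sketch are hypothetical, lists replace tensors, and plain integer lists replace the selected_indices tensor.

```python
import copy
from dataclasses import dataclass
from typing import Optional, TypeVar


@dataclass
class ToyTensorDataset:
    data: list
    targets: Optional[list] = None


@dataclass
class ToyDictionaryDataset:
    data: dict[str, list]
    targets: list


D = TypeVar("D", ToyTensorDataset, ToyDictionaryDataset)


def select_by_indices_sketch(dataset: D, selected_indices: list[int]) -> D:
    """Return a dataset of the same type holding only the selected datapoints."""
    if isinstance(dataset, ToyTensorDataset):
        subset = copy.deepcopy(dataset)  # keeps any other attributes of the original
        subset.data = [dataset.data[i] for i in selected_indices]
        if dataset.targets is not None:  # targets are optional for this type
            subset.targets = [dataset.targets[i] for i in selected_indices]
        return subset
    if isinstance(dataset, ToyDictionaryDataset):
        # Every key's list is filtered with the same indices.
        new_data = {key: [values[i] for i in selected_indices] for key, values in dataset.data.items()}
        return ToyDictionaryDataset(new_data, [dataset.targets[i] for i in selected_indices])
    raise TypeError(f"Unsupported dataset type: {type(dataset).__name__}")


tensor_ds = ToyTensorDataset(data=[10, 20, 30, 40], targets=[0, 1, 0, 1])
print(select_by_indices_sketch(tensor_ds, [0, 2]))  # ToyTensorDataset(data=[10, 30], targets=[0, 0])
```

Copying the dataset, instead of mutating it in place, leaves the original dataset intact for other subsets.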