fl4health.utils.dataset module
- class BaseDataset(transform, target_transform)
Bases: ABC, Dataset
- __init__(transform, target_transform)
Abstract base class for datasets used in this library. This class inherits from the torch Dataset base class.
- Parameters:
transform (Callable | None, optional) – Optional transformation to be applied to the input data. NOTE: This transformation is applied at load time within __getitem__. Defaults to None.
target_transform (Callable | None, optional) – Optional transformation to be applied to the target data. NOTE: This transformation is applied at load time within __getitem__. Defaults to None.
- class DictionaryDataset(data, targets)
Bases: Dataset
- __init__(data, targets)
A torch dataset that supports a dictionary of input data rather than just a torch.Tensor. This kind of dataset is useful when dealing with non-trivial inputs to a model. For example, a language model may require both token ids and attention masks. This dataset supports that functionality.
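The dictionary-of-inputs idea can be illustrated with a minimal, torch-free sketch. DictDatasetSketch is a hypothetical stand-in, not the library's implementation: plain lists replace tensors, and returning an (inputs, target) pair from __getitem__ is an assumption about the intended behavior.

```python
from typing import Any


class DictDatasetSketch:
    """Hypothetical stand-in: a map-style dataset over a dictionary of inputs."""

    def __init__(self, data: dict[str, list[Any]], targets: list[Any]) -> None:
        lengths = {len(values) for values in data.values()}
        # Every input field must describe the same number of samples as the targets.
        if lengths != {len(targets)}:
            raise ValueError("All input fields and targets must have the same length.")
        self.data = data
        self.targets = targets

    def __getitem__(self, index: int) -> tuple[dict[str, Any], Any]:
        # Index every field at the same position so the sample stays aligned.
        inputs = {key: values[index] for key, values in self.data.items()}
        return inputs, self.targets[index]

    def __len__(self) -> int:
        return len(self.targets)


dict_dataset = DictDatasetSketch(
    data={
        "input_ids": [[101, 7592, 102], [101, 2088, 102]],
        "attention_mask": [[1, 1, 1], [1, 1, 1]],
    },
    targets=[0, 1],
)
sample, label = dict_dataset[1]
```

Each sample keeps the same keys as the full dictionary, so a model that expects named inputs can unpack the sample directly.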
- class SslTensorDataset(data, targets=None, transform=None, target_transform=None)
Bases: TensorDataset
- __init__(data, targets=None, transform=None, target_transform=None)
Dataset specifically designed to perform self-supervised learning, where we don’t have a specific set of targets, because targets are derived from the data tensors.
- Parameters:
data (torch.Tensor) – Tensor representing the input data for the dataset.
targets (torch.Tensor | None, optional) – REQUIRED TO BE NONE. This argument and its type exist only to maintain compatibility with the TensorDataset base class. Defaults to None.
transform (Callable | None, optional) – Any transform to be applied to the data tensors. This transform is performed BEFORE any target transforms that produce the self-supervised targets from the data. NOTE: Both this transformation and the target_transform function are applied AT LOAD TIME. Defaults to None.
target_transform (Callable | None, optional) – Any transform to be applied to the data tensors to produce target tensors for training. This transform is performed AFTER any transforms for the data tensors themselves, and produces the self-supervised targets from the data. NOTE: These transformation functions are applied AT LOAD TIME. Defaults to None.
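The ordering described above (data transform first, then a target derived from the data) can be sketched as follows. SslDatasetSketch, scale, and reverse are hypothetical names. The sketch uses lists instead of tensors and makes two assumptions: that the target transform receives the already-transformed sample, and that without a target_transform the sample serves as its own target.

```python
from typing import Callable, Optional

Sample = list[float]


class SslDatasetSketch:
    """Hypothetical stand-in for a self-supervised dataset; lists replace tensors."""

    def __init__(
        self,
        data: list[Sample],
        targets: None = None,
        transform: Optional[Callable[[Sample], Sample]] = None,
        target_transform: Optional[Callable[[Sample], Sample]] = None,
    ) -> None:
        if targets is not None:
            raise ValueError("Targets are derived from the data, so targets must be None.")
        self.data = data
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index: int) -> tuple[Sample, Sample]:
        sample = self.data[index]
        # Step 1: the data transform, applied at load time.
        if self.transform is not None:
            sample = self.transform(sample)
        # Step 2: derive the target from the (already transformed) sample.
        # Assumption: with no target_transform, the sample is its own target.
        if self.target_transform is not None:
            target = self.target_transform(sample)
        else:
            target = sample
        return sample, target

    def __len__(self) -> int:
        return len(self.data)


def scale(sample: Sample) -> Sample:
    return [2.0 * value for value in sample]


def reverse(sample: Sample) -> Sample:
    # Toy pretext task: the target is the reversed input.
    return sample[::-1]


ssl_dataset = SslDatasetSketch([[1.0, 2.0, 3.0]], transform=scale, target_transform=reverse)
inputs, target = ssl_dataset[0]
```

Here inputs is the scaled sample and target is its reversal, so no stored labels are needed.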
- class SyntheticDataset(data, targets)
Bases: TensorDataset
- __init__(data, targets)
A dataset for synthetically created data strictly in the form of PyTorch tensors. Generally, this dataset is only used for tests.
- Parameters:
data (torch.Tensor) – Data tensor with first dimension corresponding to the number of datapoints.
targets (torch.Tensor) – Target tensor with first dimension corresponding to the number of datapoints.
- class TensorDataset(data, targets=None, transform=None, target_transform=None)
Bases: BaseDataset
- __init__(data, targets=None, transform=None, target_transform=None)
Basic dataset where the data and targets are assumed to be torch tensors. Optionally, this class allows the user to perform transformations on both the data and the targets. This is useful, for example, in performing data augmentation, label blurring, etc.
- Parameters:
data (torch.Tensor) – Input data for training.
targets (torch.Tensor | None, optional) – Target data for training. Defaults to None.
transform (Callable | None, optional) – Optional transformation to be applied to the input data. NOTE: This transformation is applied at load time within __getitem__. Defaults to None.
target_transform (Callable | None, optional) – Optional transformation to be applied to the target data. NOTE: This transformation is applied at load time within __getitem__. Defaults to None.
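"Applied at load time" means the stored data is left untouched and the transform runs each time an item is fetched. A minimal sketch of that pattern follows. LoadTimeTransformSketch and add_offset are hypothetical, lists stand in for tensors, and returning None for a missing target is an assumption made only for this illustration.

```python
from typing import Callable, Optional


class LoadTimeTransformSketch:
    """Hypothetical stand-in; plain lists replace torch tensors."""

    def __init__(
        self,
        data: list[float],
        targets: Optional[list[int]] = None,
        transform: Optional[Callable[[float], float]] = None,
        target_transform: Optional[Callable[[int], int]] = None,
    ) -> None:
        # Nothing is transformed here; the raw values are stored as given.
        self.data = data
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index: int) -> tuple[float, Optional[int]]:
        sample = self.data[index]
        if self.transform is not None:
            sample = self.transform(sample)
        # Assumption for this sketch: a dataset without targets yields None.
        target = None if self.targets is None else self.targets[index]
        if target is not None and self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self) -> int:
        return len(self.data)


calls = 0


def add_offset(value: float) -> float:
    global calls
    calls += 1  # Count how often the transform actually runs.
    return value + 0.5


tensor_like = LoadTimeTransformSketch([1.0, 2.0], targets=[0, 1], transform=add_offset)
calls_after_init = calls  # still 0: construction applies no transform
first = tensor_like[0]    # the transform runs on this access
second = tensor_like[0]   # and runs again on the next one
```

Because the transform is re-evaluated on every access, a randomized augmentation would yield a different view of the same sample in each epoch.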
- select_by_indices(dataset, selected_indices)
This function is used to extract a subset of a dataset sliced by the indices in the tensor selected_indices. The dataset returned should be of the same type as the input but with only data associated with the given indices.
- Parameters:
dataset (D) – Dataset to be “subsampled” using the provided indices.
selected_indices (torch.Tensor) – Indices within the dataset's data and targets (if they exist) to select.
- Raises:
TypeError – Raised if the provided dataset is not of a supported type.
- Returns:
Dataset with only the data associated with the provided indices. Must be of a supported type.
- Return type:
D
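The contract described above (same type out as in, only the selected rows kept, TypeError for unsupported inputs) can be sketched without torch. ListDataset, SubclassedDataset, and select_by_indices_sketch are hypothetical names. A list of integers stands in for the index tensor, and copying the dataset before slicing is an assumption about one reasonable way to preserve the type, not a statement about the library's implementation.

```python
import copy
from typing import Optional, TypeVar


class ListDataset:
    """Hypothetical supported dataset type; lists replace tensors."""

    def __init__(self, data: list[float], targets: Optional[list[int]] = None) -> None:
        self.data = data
        self.targets = targets


class SubclassedDataset(ListDataset):
    """Subclass used to show that the input type is preserved."""


D = TypeVar("D", bound=ListDataset)


def select_by_indices_sketch(dataset: D, selected_indices: list[int]) -> D:
    if not isinstance(dataset, ListDataset):
        raise TypeError(f"Unsupported dataset type: {type(dataset).__name__}")
    # Copy first so the caller's dataset is untouched and the subset keeps its type.
    subset = copy.deepcopy(dataset)
    subset.data = [dataset.data[i] for i in selected_indices]
    # Targets are optional, so they are only sliced when present.
    if dataset.targets is not None:
        subset.targets = [dataset.targets[i] for i in selected_indices]
    return subset


full = SubclassedDataset([0.1, 0.2, 0.3, 0.4], targets=[0, 1, 0, 1])
subset = select_by_indices_sketch(full, [0, 3])
```

The result is still a SubclassedDataset holding only the two selected rows, while full keeps all four.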