fl4health.datasets.rxrx1.load_data module
- construct_rxrx1_tensor_dataset(metadata, data_path, client_num, dataset_type, transform=None)
Construct a TensorDataset for rxrx1 data (https://www.rxrx.ai/rxrx1).
- Parameters:
metadata (DataFrame) – A DataFrame containing image metadata.
data_path (Path) – Root directory from which the image data should be loaded.
client_num (int) – Client number to load data for.
dataset_type (str) – “train” or “test” to specify dataset type.
transform (Callable | None) – Transformation function to apply to the images. Defaults to None.
- Returns:
A TensorDataset containing the processed images and their labels, and the label map.
- Return type:
tuple[TensorDataset, dict[int, int]]
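The entry above does not spell out how the label map is built. The following is a minimal, dependency-free sketch, assuming the map sends each distinct original label to a contiguous index 0..K-1 and that rows are filtered by dataset type. The helpers `build_label_map` and `select_rows` and the column names `"dataset"` and `"label"` are hypothetical, and rows are plain dicts instead of a pandas DataFrame so the example stays self-contained. In the real function, the remapped targets would then be converted to tensors and paired with the loaded images in a TensorDataset.

```python
from collections.abc import Sequence


def build_label_map(original_labels: Sequence[int]) -> dict[int, int]:
    """Hypothetical helper: map each distinct original label to a contiguous index.

    Sorting the distinct labels makes the mapping deterministic,
    regardless of the order in which rows appear in the metadata.
    """
    return {label: new for new, label in enumerate(sorted(set(original_labels)))}


def select_rows(rows: list[dict], dataset_type: str) -> list[dict]:
    """Hypothetical helper: keep only the metadata rows for the requested split.

    Assumes a hypothetical "dataset" column holding "train" or "test".
    """
    if dataset_type not in ("train", "test"):
        raise ValueError(f"dataset_type must be 'train' or 'test', got {dataset_type!r}")
    return [row for row in rows if row["dataset"] == dataset_type]


# Toy metadata with hypothetical column names.
metadata = [
    {"dataset": "train", "label": 1138},
    {"dataset": "train", "label": 42},
    {"dataset": "test", "label": 1138},
    {"dataset": "train", "label": 42},
]
# Build the map from all rows so train and test share the same indices.
label_map = build_label_map([row["label"] for row in metadata])
train_rows = select_rows(metadata, "train")
targets = [label_map[row["label"]] for row in train_rows]
print(label_map)  # {42: 0, 1138: 1}
print(targets)    # [1, 0, 0]
```

Building the map from every row, rather than from the selected split alone, keeps a given original label on the same index in both the train and test datasets.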
- create_splits(dataset, seed=None, train_fraction=0.8)
Splits the dataset into training and validation sets.
- Parameters:
dataset (TensorDataset) – The dataset to split.
seed (int | None, optional) – Seed used to make the random split reproducible. Defaults to None.
train_fraction (float, optional) – Fraction of data to use for training. Defaults to 0.8.
- Returns:
Indices of the datapoints selected for the train and validation sets.
- Return type:
- label_frequency(dataset, original_label_map)
Prints the frequency of each label in the dataset.
- Parameters:
dataset (TensorDataset | Subset) – The dataset to analyze.
original_label_map (dict[int, int]) – A mapping of the original labels to their new labels.
- Return type:
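A minimal sketch of the counting this entry describes, assuming `original_label_map` maps original labels to new labels as documented, so it is inverted to report counts under the original label ids. The function name `label_frequency_sketch` is hypothetical; it takes a plain list of remapped labels instead of a TensorDataset or Subset, and it returns the counts in addition to printing them purely so the result can be checked.

```python
from collections import Counter


def label_frequency_sketch(targets: list[int], original_label_map: dict[int, int]) -> dict[int, int]:
    """Hypothetical sketch: print and return how often each original label occurs."""
    # Invert original -> new so counts can be reported under the original label ids.
    new_to_original = {new: original for original, new in original_label_map.items()}
    counts = Counter(targets)
    # Report in order of the new (contiguous) label index.
    freq = {new_to_original[new]: counts[new] for new in sorted(counts)}
    for original, count in freq.items():
        print(f"Label {original}: {count}")
    return freq


freq = label_frequency_sketch([1, 0, 0], {42: 0, 1138: 1})
# Label 42: 2
# Label 1138: 1
```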
- load_rxrx1_data(data_path, client_num, batch_size, seed=None, train_val_split=0.8, num_workers=0)
Load and split the data into training and validation dataloaders.
- Parameters:
data_path (Path) – Path to the full set of data.
client_num (int) – Client number for which to load data.
batch_size (int) – Batch size for the dataloaders.
seed (int | None, optional) – Seed to fix randomness associated with data splitting. Defaults to None.
train_val_split (float, optional) – Fraction of data to put in the training dataloader; the remainder goes to the validation dataloader. Defaults to 0.8.
num_workers (int, optional) – Number of worker subprocesses used by the dataloaders. Defaults to 0, in which case data is loaded in the main process.
- Returns:
Train and validation dataloaders and a dictionary holding the size of each dataset.
- Return type:
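To relate `batch_size` to the returned dataloaders and the size dictionary, the arithmetic can be sketched as follows. This assumes the dataloaders use PyTorch's default `drop_last=False`, so the final partial batch is kept and a loader over n samples yields ceil(n / batch_size) batches. The function `loader_summary` and its dictionary keys are hypothetical and are not the keys used by `load_rxrx1_data`.

```python
import math


def loader_summary(num_train: int, num_val: int, batch_size: int) -> dict[str, int]:
    """Hypothetical sketch: dataset sizes and the batch count each loader would yield.

    Assumes drop_last=False, so a partial final batch still counts as a batch.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return {
        "train_size": num_train,
        "val_size": num_val,
        "train_batches": math.ceil(num_train / batch_size),
        "val_batches": math.ceil(num_val / batch_size),
    }


# For example, 12 training and 3 validation samples with batch_size=4.
print(loader_summary(12, 3, batch_size=4))
# {'train_size': 12, 'val_size': 3, 'train_batches': 3, 'val_batches': 1}
```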