fl4health.datasets.rxrx1.load_data module
- construct_rxrx1_tensor_dataset(metadata, data_path, client_num, dataset_type, transform=None)
Construct a TensorDataset for rxrx1 data.
- Parameters:
metadata (DataFrame) – A DataFrame containing image metadata.
data_path (Path) – Root directory from which the image data should be loaded.
client_num (int) – Client number to load data for.
dataset_type (str) – ‘train’ or ‘test’ to specify dataset type.
transform (Callable | None) – Transformation function to apply to the images. Defaults to None.
- Returns:
A tuple of the TensorDataset containing the processed images and the label map.
- Return type:
tuple[TensorDataset, dict[int, int]]
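The second element of the returned tuple is a label map of type dict[int, int]. As a minimal sketch of what such a map can look like, the snippet below assigns each distinct original label a contiguous index in sorted order. The helper names build_label_map and remap_labels are hypothetical, and the sorted, contiguous ordering is an assumption made for illustration; it is not taken from the fl4health implementation.

```python
def build_label_map(original_labels: list[int]) -> dict[int, int]:
    """Hypothetical helper: map each distinct original label to a contiguous index.

    Assumption: new labels are assigned in ascending order of the original labels.
    """
    return {label: new for new, label in enumerate(sorted(set(original_labels)))}


def remap_labels(original_labels: list[int], label_map: dict[int, int]) -> list[int]:
    """Hypothetical helper: translate original labels into their new labels."""
    return [label_map[label] for label in original_labels]


# Illustrative labels only; these values are not taken from the rxrx1 metadata.
original = [1108, 5, 5, 42]
label_map = build_label_map(original)
remapped = remap_labels(original, label_map)
```

A map of this shape is what the original_label_map parameter of label_frequency expects: original labels as keys, new labels as values.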
- create_splits(dataset, seed=None, train_fraction=0.8)
Splits the dataset into training and validation sets.
- Parameters:
dataset (Dataset) – The dataset to split.
seed (int | None) – Random seed used to make the split reproducible. Defaults to None.
train_fraction (float) – Fraction of data to use for training. Defaults to 0.8.
- Returns:
(train_dataset, val_dataset)
- Return type:
Tuple
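To illustrate how seed and train_fraction interact in a split of this kind, here is a minimal standard-library sketch. The function split_indices is hypothetical and works on plain indices; it is not the fl4health implementation, which may differ (for example, in how samples are shuffled or grouped).

```python
import random
from typing import Optional


def split_indices(
    num_samples: int, seed: Optional[int] = None, train_fraction: float = 0.8
) -> tuple[list[int], list[int]]:
    """Hypothetical helper: shuffle indices and cut them at train_fraction."""
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError("train_fraction must be between 0 and 1")
    indices = list(range(num_samples))
    # A dedicated Random instance keeps the shuffle reproducible for a given seed
    # without touching the global random state.
    random.Random(seed).shuffle(indices)
    split_point = int(train_fraction * num_samples)
    return indices[:split_point], indices[split_point:]


# With 10 samples and the default fraction, 8 indices go to training and 2 to validation.
train_idx, val_idx = split_indices(10, seed=0)
```

Passing the same seed twice yields the same partition, while seed=None gives a different shuffle on each call.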
- label_frequency(dataset, original_label_map)
Prints the frequency of each label in the dataset.
- Parameters:
dataset (TensorDataset | Subset) – The dataset to analyze.
original_label_map (dict[int, int]) – A mapping of the original labels to their new labels.
- Return type: