fl4health.datasets.rxrx1.load_data module

construct_rxrx1_tensor_dataset(metadata, data_path, client_num, dataset_type, transform=None)[source]

Construct a TensorDataset for rxrx1 data.

Parameters:
  • metadata (DataFrame) – A DataFrame containing image metadata.

  • data_path (Path) – Root directory from which the image data should be loaded.

  • client_num (int) – Client number to load data for.

  • dataset_type (str) – ‘train’ or ‘test’ to specify dataset type.

  • transform (Callable | None) – Transformation function to apply to the images. Defaults to None.

Returns:

A tuple of the TensorDataset containing the processed images and the label map.

Return type:

tuple[TensorDataset, dict[int, int]]
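
To illustrate the label map in the return value, here is a minimal sketch of how a dict[int, int] from original labels to contiguous new labels could be built. The helper build_label_map and the sample label values are hypothetical and are not part of fl4health; the library may construct its map differently.

```python
def build_label_map(original_labels: list[int]) -> dict[int, int]:
    """Map each distinct original label to a contiguous index starting at 0.

    Hypothetical helper for illustration only; not the fl4health implementation.
    """
    unique_labels = sorted(set(original_labels))
    return {original: new for new, original in enumerate(unique_labels)}


# Made-up original labels for a single client.
sample_labels = [31, 10, 25, 10, 31]
label_map = build_label_map(sample_labels)

# Remap the original labels to their new, contiguous values.
remapped_labels = [label_map[label] for label in sample_labels]
```

A contiguous range of labels is what most classification losses expect, which is one plausible reason for returning such a map alongside the dataset.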

create_splits(dataset, seed=None, train_fraction=0.8)[source]

Splits the dataset into training and validation sets.

Parameters:
  • dataset (Dataset) – The dataset to split.

  • train_fraction (float) – Fraction of data to use for training. Defaults to 0.8.

Returns:

(train_dataset, val_dataset)

Return type:

Tuple

label_frequency(dataset, original_label_map)[source]

Prints the frequency of each label in the dataset.

Parameters:
  • dataset (TensorDataset | Subset) – The dataset to analyze.

  • original_label_map (dict[int, int]) – A mapping of the original labels to their new labels.

Return type:

None
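
As a rough sketch of the counting involved, the snippet below tallies labels and reports them under their original values by inverting the label map. The function count_label_frequency is hypothetical and returns a dictionary instead of printing; it is not the fl4health implementation, and the sample data is made up.

```python
from collections import Counter


def count_label_frequency(
    labels: list[int], original_label_map: dict[int, int]
) -> dict[int, int]:
    """Count how often each label occurs, keyed by the original label value.

    Hypothetical helper for illustration only.
    """
    # Invert the map so that new labels can be traced back to original ones.
    new_to_original = {new: original for original, new in original_label_map.items()}
    counts = Counter(labels)
    # Labels that never occur in the data are reported with a count of 0.
    return {new_to_original[new]: counts.get(new, 0) for new in sorted(new_to_original)}


sample_map = {10: 0, 25: 1, 31: 2}  # original label -> new label
frequencies = count_label_frequency([0, 1, 1, 2, 2, 2], sample_map)
for original_label, count in frequencies.items():
    print(f"Label {original_label}: {count}")
```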

load_rxrx1_data(data_path, client_num, batch_size, seed=None, train_val_split=0.8, num_workers=0)[source]

Return type:

tuple[DataLoader, DataLoader, dict[str, int]]

load_rxrx1_test_data(data_path, client_num, batch_size, num_workers=0)[source]

Return type:

tuple[DataLoader, dict[str, int]]