Source code for fl4health.datasets.rxrx1.preprocess

import argparse
import os
import pickle
from pathlib import Path
from typing import Any

import pandas as pd
import torch
from PIL import Image
from torchvision.transforms import ToTensor


def filter_and_save_data(metadata: pd.DataFrame, top_sirna_ids: list[int], cell_type: str, output_path: Path) -> None:
    """
    Keep only the rows of a single cell type whose sirna_id is among the selected ids, and write them to CSV.

    Args:
        metadata (pd.DataFrame): Full metadata table; must have ``sirna_id`` and ``cell_type`` columns.
        top_sirna_ids (list[int]): The sirna_id values a row must have to be kept.
        cell_type (str): The cell type a row must have to be kept.
        output_path (Path): Destination CSV file (written without the DataFrame index).
    """
    has_selected_sirna = metadata["sirna_id"].isin(top_sirna_ids)
    has_cell_type = metadata["cell_type"] == cell_type
    metadata.loc[has_selected_sirna & has_cell_type].to_csv(output_path, index=False)
def load_image(row: dict[str, Any], root: Path) -> torch.Tensor:
    """
    Load an image tensor for a given row of metadata.

    Each channel is stored as its own PNG file. Only channels 1-3 are read; each is converted to grayscale
    and the per-channel tensors are concatenated along dim 0.

    Args:
        row (dict[str, Any]): A row of metadata containing experiment, plate, well, and site information.
        root (Path): Root directory containing the ``images`` folder.

    Returns:
        torch.Tensor: The per-channel tensors concatenated along dim 0.

    Raises:
        FileNotFoundError: If any of the three channel images is missing.
    """
    experiment = row["experiment"]
    plate = row["plate"]
    well = row["well"]
    site = row["site"]

    # Rxrx1 originally consists of 6 channels, but to reduce the computational cost, we only use 3 channels
    # following previous works such as https://github.com/p-lambda/wilds.
    image_paths = [
        os.path.join(root, f"images/{experiment}/Plate{plate}/{well}_s{site}_w{channel}.png")
        for channel in range(1, 4)
    ]

    # Check every channel before decoding any of them, so a missing file fails fast without wasted work.
    for image_path in image_paths:
        if not Path(image_path).exists():
            raise FileNotFoundError(f"Image not found at {image_path}")

    to_tensor = ToTensor()
    images = []
    for image_path in image_paths:
        # Image.open is lazy and keeps the file handle open until the image is closed. The context manager
        # releases it as soon as the pixels have been converted.
        with Image.open(image_path) as image:
            images.append(to_tensor(image.convert("L")))

    # Concatenate the three channels into one tensor
    return torch.cat(images, dim=0)
def process_data(metadata: pd.DataFrame, input_dir: Path, output_dir: Path, client_num: int, type_data: str) -> None:
    """
    Process the entire dataset, loading the image tensor for each row and pickling it to disk.

    Each tensor is written to ``output_dir/{type_data}_data_{client_num + 1}/image_{index}.pkl``, where ``index``
    is the row's index label in ``metadata``. The per-client subdirectory is created if it does not exist.

    Args:
        metadata (pd.DataFrame): Metadata containing information about all images.
        input_dir (Path): Input directory containing the image files.
        output_dir (Path): Output directory under which the per-client subdirectory of pickled tensors is created.
        client_num (int): Zero-based client number; the subdirectory name uses ``client_num + 1``.
        type_data (str): 'train' or 'test' to specify dataset type.
    """
    client_dir = os.path.join(output_dir, f"{type_data}_data_{client_num + 1}")
    # main() only creates the top-level output directory. Without this, opening the first pickle file
    # for writing would fail with FileNotFoundError.
    os.makedirs(client_dir, exist_ok=True)

    input_root = Path(input_dir)
    for i, row in metadata.iterrows():
        image_tensor = load_image(row.to_dict(), input_root)
        save_to_pkl(image_tensor, os.path.join(client_dir, f"image_{i}.pkl"))
def save_to_pkl(data: torch.Tensor, output_path: str) -> None:
    """
    Serialize ``data`` with :mod:`pickle` into the file at ``output_path``, overwriting any existing file.

    Args:
        data (torch.Tensor): Object to serialize.
        output_path (str): Destination file path.
    """
    destination = Path(output_path)
    with destination.open("wb") as handle:
        pickle.dump(data, handle)
def main(dataset_dir: Path, num_top_sirna_ids: int = 50) -> None:
    """
    Build per-client RxRx1 data: one client per cell type, restricted to the most frequent sirna_ids.

    For client ``k`` (1-based), the filtered metadata is saved to ``clients/meta_data_{k}.csv``. The train and
    test images it references are pickled by ``process_data`` into ``clients/train_data_{k}/`` and
    ``clients/test_data_{k}/``.

    Args:
        dataset_dir (Path): Dataset root containing ``metadata.csv`` and the ``images`` folder. A plain string
            path is also accepted.
        num_top_sirna_ids (int): Number of most frequent ``sirna_id`` values to keep. Defaults to 50.
    """
    # argparse hands us a str; normalize once so all path building below is consistent.
    dataset_dir = Path(dataset_dir)
    output_dir = dataset_dir / "clients"
    output_dir.mkdir(parents=True, exist_ok=True)

    data = pd.read_csv(dataset_dir / "metadata.csv")
    # value_counts sorts by descending frequency, so head(k) gives the k most frequent sirna_ids.
    top_sirna_ids = data["sirna_id"].value_counts().head(num_top_sirna_ids).index.tolist()

    # Define cell types to distribute data based on them for each client
    cell_types = ["RPE", "HUVEC", "HEPG2", "U2OS"]
    output_files = [output_dir / f"meta_data_{i + 1}.csv" for i in range(len(cell_types))]

    # Filter and save data for each client
    for cell_type, output_path in zip(cell_types, output_files):
        filter_and_save_data(data, top_sirna_ids, cell_type, output_path)

    for client_num, metadata_path in enumerate(output_files):
        # Re-reading from disk gives each client's rows a fresh 0-based index. process_data uses that
        # index in the pickle file names.
        metadata = pd.read_csv(metadata_path)
        # Split the metadata into train and test datasets
        train_metadata = metadata[metadata["dataset"] == "train"]
        test_metadata = metadata[metadata["dataset"] == "test"]
        process_data(train_metadata, dataset_dir, output_dir, client_num, "train")
        process_data(test_metadata, dataset_dir, output_dir, client_num, "test")
if __name__ == "__main__":
    # Command-line entry point: a single positional argument naming the dataset root.
    cli = argparse.ArgumentParser(
        description="Filter dataset by the most frequent sirna_id and cell_type."
    )
    cli.add_argument(
        "dataset_dir",
        type=str,
        help="Path to the dataset directory containing metadata.csv",
    )
    parsed = cli.parse_args()
    main(parsed.dataset_dir)