Source code for fl4health.reporting.wandb_reporter

from enum import Enum
from logging import WARNING
from pathlib import Path
from typing import Any

import wandb
import wandb.wandb_run
from flwr.common.logger import log

from fl4health.reporting.base_reporter import BaseReporter


class StepType(Enum):
    ROUND = "round"
    EPOCH = "epoch"
    STEP = "step"
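
# Note (illustrative, not part of the original module): StepType is a value-backed Enum, so
# WandBReporter accepts either a StepType member or its string value. A sketch of the coercion
# performed in WandBReporter.__init__:
#     assert StepType("epoch") is StepType.EPOCH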

# TODO: Add ability to parse data types and save certain data types in specific ways
# (eg. Artifacts, Tables, etc.)
class WandBReporter(BaseReporter):
    def __init__(
        self,
        wandb_step_type: StepType | str = StepType.ROUND,
        project: str | None = None,
        entity: str | None = None,
        config: dict | str | None = None,
        group: str | None = None,
        job_type: str | None = None,
        tags: list[str] | None = None,
        name: str | None = None,
        id: str | None = None,
        resume: str = "allow",
        **kwargs: Any,
    ) -> None:
        """Reporter that logs client and server data to a wandb (Weights and Biases) server.

        Args:
            wandb_step_type (StepType | str, optional): Whether to use the 'round', 'epoch' or 'step' as the
                wandb_step value when logging information to the wandb server. Defaults to StepType.ROUND.
            project (str | None, optional): The name of the project where you're sending the new run. If
                unspecified, wandb will try to infer the project or set it to "uncategorized".
            entity (str | None, optional): An entity is a username or team name where you're sending runs.
                This entity must exist before you can send runs there, so make sure to create your account or
                team in the UI before starting to log runs. If you don't specify an entity, the run will be
                sent to your default entity. Change your default entity in your settings under "default
                location to create new projects".
            config (dict | str | None, optional): This sets wandb.config, a dictionary-like object for saving
                inputs to your job such as hyperparameters for a model. If dict: will load the key value
                pairs into the wandb.config object. If str: will look for a yaml file by that name, and load
                config from that file into the wandb.config object.
            group (str | None, optional): Specify a group to organize individual runs into a larger
                experiment.
            job_type (str | None, optional): Specify the type of run, useful when grouping runs.
            tags (list[str] | None, optional): A list of strings, which will populate the list of tags on
                this run. If you want to add tags to a resumed run without overwriting its existing tags, use
                run.tags += ["new_tag"] after wandb.init().
            name (str | None, optional): A short display name for this run. Default generates a random
                two-word name.
            id (str | None, optional): A unique ID for this run. It must be unique in the project, and if you
                delete a run you can't reuse the ID.
            resume (str): Indicates how to handle the case when a run has the same entity, project and run id
                as a previous run. 'must' enforces that the run must resume from the run with the same id and
                throws an error if it does not exist. 'never' enforces that a run will not resume and throws
                an error if the run id exists. 'allow' resumes if the run id already exists. Defaults to
                'allow'.
            kwargs (Any): Keyword arguments to wandb.init excluding the ones explicitly described above.
                Documentation here: https://docs.wandb.ai/ref/python/init/
        """
        # Create wandb metadata dir if necessary
        if kwargs.get("dir") is not None:
            Path(kwargs["dir"]).mkdir(exist_ok=True)

        # Set attributes
        self.wandb_init_kwargs = kwargs
        self.wandb_step_type = StepType(wandb_step_type)
        self.run_started = False
        self.initialized = False
        self.project = project
        self.entity = entity
        self.config = config
        self.group = group
        self.job_type = job_type
        self.tags = tags
        self.name = name
        self.id = id
        self.resume = resume

        # Keep track of epoch and step. Initialize as 0.
        self.current_epoch = 0
        self.current_step = 0

        # Initialize run later to avoid creating runs while debugging
        self.run: wandb.wandb_run.Run
    def initialize(self, **kwargs: Any) -> None:
        """Checks if an id was provided by the client or server.

        If an id was passed to the WandBReporter on init then it takes priority over the one passed by the
        client/server.
        """
        if self.id is None:
            self.id = kwargs.get("id")
        self.initialized = True
    def define_metrics(self) -> None:
        """This method defines some of the metrics we expect to see from the Basic Client and server.

        Note that you do not have to define metrics, but it can be useful for determining what should and
        shouldn't go into the run summary.
        """
        # Note that the hidden argument is not working. Raised issue here: https://github.com/wandb/wandb/issues/8890
        # Round, epoch and step
        self.run.define_metric("fit_step", summary="none", hidden=True)  # Current fit step
        self.run.define_metric("fit_epoch", summary="none", hidden=True)  # Current fit epoch
        self.run.define_metric("round", summary="none", hidden=True)  # Current server round
        self.run.define_metric("round_start", summary="none", hidden=True)
        self.run.define_metric("round_end", summary="none", hidden=True)
        # A server round contains a fit_round and maybe also an evaluate round
        self.run.define_metric("fit_round_start", summary="none", hidden=True)
        self.run.define_metric("fit_round_end", summary="none", hidden=True)
        self.run.define_metric("eval_round_start", summary="none", hidden=True)
        self.run.define_metric("eval_round_end", summary="none", hidden=True)
        # The metrics computed on all the samples from the final epoch, or the entire round if training by steps
        self.run.define_metric("fit_round_time_elapsed", summary="none")
        self.run.define_metric("eval_round_time_elapsed", summary="none")
        self.run.define_metric("fit_round_metrics", step_metric="round", summary="best")
        self.run.define_metric("eval_round_metrics", step_metric="round", summary="best")
        # Average of the losses for each step in the final epoch, or the entire round if training by steps.
        self.run.define_metric("fit_round_losses", step_metric="round", summary="best", goal="minimize")
        self.run.define_metric("eval_round_loss", step_metric="round", summary="best", goal="minimize")
        # The metrics computed at the end of the epoch on all the samples from the epoch
        self.run.define_metric("fit_round_metrics", step_metric="fit_epoch", summary="best")
        # Average of the losses for each step in the epoch
        self.run.define_metric("fit_epoch_losses", step_metric="fit_epoch", summary="best", goal="minimize")
        # The loss and metrics for each individual step
        self.run.define_metric("fit_step_metrics", step_metric="fit_step", summary="best")
        self.run.define_metric("fit_step_losses", step_metric="fit_step", summary="best", goal="minimize")
        # FlServer (Base Server) specific metrics
        self.run.define_metric("val - loss - aggregated", step_metric="round", summary="best", goal="minimize")
        self.run.define_metric("eval_round_metrics_aggregated", step_metric="round", summary="best")
        # The following metrics don't work with wandb since they are currently obtained after training instead of live
        self.run.define_metric("val - loss - centralized", step_metric="round", summary="best", goal="minimize")
        self.run.define_metric("eval_round_metrics_centralized", step_metric="round", summary="best")
    def start_run(self, wandb_init_kwargs: dict[str, Any]) -> None:
        """Initializes the wandb run.

        We avoid doing this in the self.init function so that when debugging, jobs that fail before training
        starts do not get uploaded to wandb.

        Args:
            wandb_init_kwargs (dict[str, Any]): Keyword arguments for wandb.init() excluding the ones
                explicitly accessible through WandBReporter.init().
        """
        if not self.initialized:
            self.initialize()

        self.run = wandb.init(
            project=self.project,
            entity=self.entity,
            config=self.config,
            group=self.group,
            job_type=self.job_type,
            tags=self.tags,
            name=self.name,
            id=self.id,
            resume=self.resume,
            **wandb_init_kwargs,  # Other less commonly used kwargs
        )
        self.run_id = self.run._run_id  # If run_id was None, we need to reset run id
        self.run_started = True

        # Wandb metric definitions
        self.define_metrics()
    def report(
        self,
        data: dict[str, Any],
        round: int | None = None,
        epoch: int | None = None,
        step: int | None = None,
    ) -> None:
        """Reports wandb compatible data to the wandb server.

        Data passed to self.report is always reported. If round is None, the data is reported as config
        information. If round is specified, the data is logged to the wandb run at the current wandb step,
        which is either the current round, epoch or step depending on the wandb_step_type passed on
        initialization.

        The current epoch and step are initialized at 0 and updated internally when specified as arguments to
        report. Therefore, leaving epoch or step as None will overwrite the data for the previous epoch/step
        if the key is the same, otherwise the new key-value pairs are added. For example, if {"loss": value}
        is logged every epoch but wandb_step_type is 'round', then the value for "loss" at round 1 will be
        its value at the last epoch of that round. You can only update or overwrite the current wandb step;
        previous steps cannot be modified.

        Args:
            data (dict[str, Any]): Dictionary of wandb compatible data to log.
            round (int | None, optional): The current FL round. If None, this indicates that the method was
                called outside of a round (e.g. for summary information). Defaults to None.
            epoch (int | None, optional): The current epoch (in total across all rounds). If None then this
                method was not called at or within the scope of an epoch. Defaults to None.
            step (int | None, optional): The current step (in total across all rounds and epochs). If None
                then this method was called outside the scope of a training or evaluation step (e.g. at the
                end of an epoch or round). Defaults to None.
        """
        # Now that report has been called we are finally forced to start the run.
        if not self.run_started:
            self.start_run(self.wandb_init_kwargs)

        # If round is None, assume data is summary information.
        if round is None:
            wandb.config.update(data)
            return

        # Update current epoch and step if they were specified
        if epoch is not None:
            if epoch < self.current_epoch:
                log(
                    WARNING,
                    f"The specified current epoch ({epoch}) is less than a previous "
                    f"current epoch ({self.current_epoch})",
                )
            self.current_epoch = epoch
        if step is not None:
            if step < self.current_step:
                log(
                    WARNING,
                    f"The specified current step ({step}) is less than a previous "
                    f"current step ({self.current_step})",
                )
            self.current_step = step

        # Log based on step type
        if self.wandb_step_type == StepType.ROUND:
            self.run.log(data, step=round)
        elif self.wandb_step_type == StepType.EPOCH:
            self.run.log(data, step=self.current_epoch)
        elif self.wandb_step_type == StepType.STEP:
            self.run.log(data, step=self.current_step)
    def shutdown(self) -> None:
        """Finishes the wandb run and uploads any remaining data."""
        self.run.finish()
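
Below is a minimal usage sketch (not part of the module above) showing how the reporter can be driven
directly: construct it, report summary/config data with no round, log per-round data, and finish the run.
The project and group names are hypothetical, and WANDB_MODE=offline is set so the example does not need a
reachable wandb server; in a real fl4health job the reporter would typically be handed to a client or server,
which calls report for you.

import os

from fl4health.reporting.wandb_reporter import StepType, WandBReporter

# Keep the example self-contained: log locally only (wandb offline mode).
os.environ["WANDB_MODE"] = "offline"

# One wandb step per FL round; the string "round" would be coerced to StepType.ROUND as well.
reporter = WandBReporter(
    wandb_step_type=StepType.ROUND,
    project="fl4health-demo",        # hypothetical project name
    group="wandb-reporter-example",  # hypothetical group name
    tags=["demo"],
)
reporter.initialize(id="client-0")  # optional; a fixed id allows resuming with resume="allow"

# round=None: the data is treated as summary/config information (stored via wandb.config.update).
reporter.report({"learning_rate": 0.01, "local_epochs": 2})

# With a round given, data is logged at wandb step == round (since wandb_step_type is ROUND).
for fl_round in range(1, 4):
    reporter.report(
        {"fit_round_losses": 1.0 / fl_round},
        round=fl_round,
        epoch=2 * fl_round,  # tracked internally; would become the wandb step under StepType.EPOCH
    )

reporter.shutdown()  # finish the run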