import io
import os
import signal
import sys
import warnings
from collections.abc import Callable, Sequence
from enum import Enum
from importlib import reload
from logging import DEBUG, INFO, WARN, Logger
from math import ceil
from typing import Any, no_type_check

import numpy as np
import torch
from flwr.common.logger import log
from torch import nn
from torch.nn.modules.loss import _Loss
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

from fl4health.utils.typing import LogLevel

with warnings.catch_warnings():
    # Silences a bunch of deprecation warnings related to scipy.ndimage
    # Raised an issue with nnunet: https://github.com/MIC-DKFZ/nnUNet/issues/2370
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
    from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
    from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
    from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset

class NnunetConfig(Enum):
    """
    The possible nnunet model configs as of nnunetv2 version 2.5.1.
    See https://github.com/MIC-DKFZ/nnUNet/tree/v2.5.1
    """

    _2D = "2d"
    _3D_FULLRES = "3d_fullres"
    _3D_CASCADE = "3d_cascade_fullres"
    _3D_LOWRES = "3d_lowres"

NNUNET_DEFAULT_NP = {  # Nnunet's default number of processes for each config
    NnunetConfig._2D: 8,
    NnunetConfig._3D_FULLRES: 4,
    NnunetConfig._3D_CASCADE: 4,
    NnunetConfig._3D_LOWRES: 8,
}

NNUNET_N_SPATIAL_DIMS = {  # The number of spatial dims for each config
    NnunetConfig._2D: 2,
    NnunetConfig._3D_FULLRES: 3,
    NnunetConfig._3D_CASCADE: 3,
    NnunetConfig._3D_LOWRES: 3,
}

def use_default_signal_handlers(fn: Callable) -> Callable:
    """
    A decorator that resets the ``SIGINT`` and ``SIGTERM`` signal handlers back to the python
    defaults for the execution of the wrapped method.

    flwr 1.9.0 overrides the default signal handlers with handlers that raise an error on any
    interruption or termination. Since nnunet spawns child processes which inherit these handlers,
    when those subprocesses are terminated (which is expected behavior), the flwr signal handlers
    raise an error (which we don't want). Flwr is expected to fix this in the next release. See the
    following issue: https://github.com/adap/flower/issues/3837
    """

    def new_fn(*args: Any, **kwargs: Any) -> Any:
        # Set SIGINT and SIGTERM back to the defaults. signal.signal returns the previous handler
        sigint_old = signal.signal(signal.SIGINT, signal.default_int_handler)
        sigterm_old = signal.signal(signal.SIGTERM, signal.SIG_DFL)
        # Execute the function
        output = fn(*args, **kwargs)
        # Reset the handlers back to what they were before the function call
        signal.signal(signal.SIGINT, sigint_old)
        signal.signal(signal.SIGTERM, sigterm_old)
        return output

    return new_fn

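# Usage sketch (hypothetical function name): any method that spawns and tears down nnunet
# subprocesses can be wrapped so that normal subprocess termination is not flagged as an error
# by flwr's handlers.
#
#     @use_default_signal_handlers
#     def run_nnunet_preprocessing() -> None:
#         ...  # spawn nnunet worker processes that may be terminated as part of normal operation
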
def reload_modules(packages: Sequence[str]) -> None:
    """
    Given the names of one or more packages, subpackages or modules, reloads all the modules within
    the scope of each package, or the modules themselves if a module was specified.

    Args:
        packages (Sequence[str]): The absolute names of the packages, subpackages or modules to
            reload. The entire import hierarchy must be specified, e.g. ``package.subpackage`` to
            reload all modules in subpackage, not just ``subpackage``. Packages are reloaded in the
            order they are given.
    """
    for m_name, module in list(sys.modules.items()):
        for package in packages:
            if m_name.startswith(package):
                try:
                    reload(module)
                except Exception as e:
                    log(DEBUG, f"Failed to reload module {m_name}: {e}")

def set_nnunet_env(verbose: bool = False, **kwargs: str) -> None:
    """
    For each keyword argument, sets the environment variable with the same name to the given value
    and then reloads nnunet. Values must be strings.

    This is necessary because nnunet checks some environment variables on import, and therefore it
    must be imported or reloaded after they are set.
    """
    # Set environment variables
    for key, val in kwargs.items():
        os.environ[key] = val
        if verbose:
            log(INFO, f"Resetting env var '{key}' to '{val}'")

    # It's necessary to reload nnunetv2.paths first, then other modules with env vars
    reload_modules(["nnunetv2.paths"])
    reload_modules(["nnunetv2.default_n_proc_DA", "nnunetv2.configuration"])
    # Reload whatever depends on nnunetv2 environment variables.
    # Be careful: if you reload something with an enum in it, things get messed up.
    reload_modules(["nnunetv2", "fl4health.clients.nnunet_client"])

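# Usage sketch (paths are illustrative): because nnunet reads its environment variables at import
# time, set_nnunet_env both sets them and reloads the relevant modules.
#
#     set_nnunet_env(
#         verbose=True,
#         nnUNet_raw="/data/nnunet_raw",
#         nnUNet_preprocessed="/data/nnunet_preprocessed",
#         nnUNet_results="/data/nnunet_results",
#     )
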
# The two convert deep supervision methods are necessary because fl4health requires
# predictions, targets and inputs to be single torch.Tensors or dicts of torch.Tensors

def convert_deep_supervision_list_to_dict(
    tensor_list: list[torch.Tensor] | tuple[torch.Tensor], num_spatial_dims: int
) -> dict[str, torch.Tensor]:
    """
    Converts a list of ``torch.Tensors`` to a dictionary. Names the keys for each tensor based on
    the spatial resolution of the tensor and its index in the list. Useful for nnunet models with
    deep supervision where model outputs and targets loaded by the dataloader are lists. Assumes
    the spatial dimensions of the tensors are last.

    Args:
        tensor_list (list[torch.Tensor]): A list of tensors, usually either nnunet model outputs
            or targets, to be converted into a dictionary.
        num_spatial_dims (int): The number of spatial dimensions. Assumes the spatial dimensions
            are last.

    Returns:
        dict[str, torch.Tensor]: A dictionary containing the tensors as values where the keys are
        'i-XxYxZ' where i was the tensor's index in the list and X, Y, Z are the spatial dimensions
        of the tensor.
    """
    tensors = {}
    for i, tensor in enumerate(tensor_list):
        # Generate a key based on the tensor's index in the list and its spatial dimensions
        key = str(i) + "-" + "x".join([str(s) for s in tensor.shape[-num_spatial_dims:]])
        tensors[key] = tensor
    return tensors

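# For example (illustrative shapes), three deep supervision outputs for a 2d config are keyed by
# list index and spatial resolution:
#
#     outputs = [torch.rand(2, 4, 128, 128), torch.rand(2, 4, 64, 64), torch.rand(2, 4, 32, 32)]
#     convert_deep_supervision_list_to_dict(outputs, num_spatial_dims=2).keys()
#     # dict_keys(['0-128x128', '1-64x64', '2-32x32'])
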
def convert_deep_supervision_dict_to_list(tensor_dict: dict[str, torch.Tensor]) -> list[torch.Tensor]:
    """
    Converts a dictionary of tensors back into a list so that it can be used by nnunet deep
    supervision loss functions.

    Args:
        tensor_dict (dict[str, torch.Tensor]): Dictionary containing ``torch.Tensors``. The keys
            must start with 'X-' where X is an integer representing the index at which the tensor
            should be placed in the output list.

    Returns:
        list[torch.Tensor]: A list of ``torch.Tensors``.
    """
    sorted_list = sorted(tensor_dict.items(), key=lambda x: int(x[0].split("-")[0]))
    return [tensor for key, tensor in sorted_list]

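# Since the keys are sorted by their integer prefix, this inverts the conversion above regardless
# of dictionary ordering, e.g. (with a and b standing in for tensors):
#
#     convert_deep_supervision_dict_to_list({"1-64x64": b, "0-128x128": a})  # [a, b]
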
def get_segs_from_probs(preds: torch.Tensor, has_regions: bool = False, threshold: float = 0.5) -> torch.Tensor:
    """
    Converts the nnunet model output probabilities to predicted segmentations.

    Args:
        preds (torch.Tensor): The one hot encoded model output probabilities with shape
            (batch, classes, \\*additional_dims). The background should be a separate class.
        has_regions (bool, optional): If True, predicted segmentations can be multiple classes at
            once. The exception is the background class which is assumed to be the first class
            (class 0). If False, each value in the predicted segmentations has only a single class.
            Defaults to False.
        threshold (float): When ``has_regions`` is True, this is the threshold value used to
            determine whether or not an output is a part of a class.

    Returns:
        torch.Tensor: Tensor containing the predicted segmentations as a one hot encoded binary
        tensor of 64-bit integers.
    """
    if has_regions:
        pred_segs = preds > threshold
        # Mask is the inverse of the background class. Ensures that values classified as
        # background are not part of another class
        mask = ~pred_segs[:, 0]
        return pred_segs * mask
    else:
        pred_segs = preds.argmax(1)[:, None]  # shape (batch, 1, *additional_dims)
        # One hot encode (OHE) the predicted segmentations again.
        # WARNING: Note the '_' after scatter. scatter_ and scatter are both functions with
        # different functionality. It is easy to introduce a bug here by using the wrong one
        pred_segs_one_hot = torch.zeros(preds.shape, device=preds.device, dtype=torch.float32)
        pred_segs_one_hot.scatter_(1, pred_segs, 1)
        # Convert output preds to long since it is binary
        return pred_segs_one_hot.long()

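# For example (illustrative values), with has_regions=False a (batch=1, classes=2, 2, 1)
# probability tensor is collapsed with argmax and then re-one-hot-encoded:
#
#     probs = torch.tensor([[[[0.9], [0.2]], [[0.1], [0.8]]]])  # shape (1, 2, 2, 1)
#     get_segs_from_probs(probs)
#     # tensor([[[[1], [0]], [[0], [1]]]])
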
def collapse_one_hot_tensor(input: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """
    Collapses a one hot encoded tensor along the given dimension so that it is no longer one hot
    encoded.

    Args:
        input (torch.Tensor): The binary one hot encoded tensor.
        dim (int): The dimension along which to collapse the tensor. Defaults to 0.

    Returns:
        torch.Tensor: Integer tensor with the specified dim collapsed.
    """
    return torch.argmax(input.long(), dim=dim).to(input.device)

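# For example (illustrative values), collapsing along dim=0 picks the index of the hot entry in
# each column:
#
#     collapse_one_hot_tensor(torch.tensor([[0, 1], [1, 0]]), dim=0)  # tensor([1, 0])
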
def get_dataset_n_voxels(source_plans: dict, n_cases: int) -> float:
    """
    Determines the total number of voxels in the dataset. Used by ``NnunetClient`` to determine
    the maximum batch size.

    Args:
        source_plans (dict): The nnunet plans dict that is being modified.
        n_cases (int): The number of cases in the dataset.

    Returns:
        float: The total number of voxels in the local client dataset.
    """
    # Need to determine input dimensionality
    if NnunetConfig._3D_FULLRES.value in source_plans["configurations"]:
        cfg = source_plans["configurations"][NnunetConfig._3D_FULLRES.value]
    else:
        cfg = source_plans["configurations"][NnunetConfig._2D.value]

    # Get the total number of voxels in the dataset
    image_shape = cfg["median_image_size_in_voxels"]
    approx_n_voxels = float(np.prod(image_shape, dtype=np.float64) * n_cases)
    return approx_n_voxels

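# For example (illustrative plans excerpt), 10 cases with a median image size of 128x128x64
# voxels yield 128 * 128 * 64 * 10 voxels:
#
#     plans = {"configurations": {"3d_fullres": {"median_image_size_in_voxels": [128, 128, 64]}}}
#     get_dataset_n_voxels(plans, n_cases=10)  # 10485760.0
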
def prepare_loss_arg(tensor: torch.Tensor | dict[str, torch.Tensor]) -> torch.Tensor | list[torch.Tensor]:
    """
    Converts pred and target tensors into the proper data type to be passed to the nnunet loss
    functions.

    Args:
        tensor (torch.Tensor | dict[str, torch.Tensor]): The input tensor.

    Returns:
        torch.Tensor | list[torch.Tensor]: The tensor ready to be passed to the loss function. A
        single tensor if not using deep supervision and a list of tensors if deep supervision is on.
    """
    # TODO: Track deep supervision with an explicit boolean state instead of inferring it from the
    # structure of the input
    if isinstance(tensor, torch.Tensor):
        return tensor  # If the input is a tensor then no changes are required
    elif isinstance(tensor, dict):
        if len(tensor) > 1:
            # Assume deep supervision is on and return a list
            return convert_deep_supervision_dict_to_list(tensor)
        else:
            # If the dict has only one item, assume deep supervision is off
            return list(tensor.values())[0]  # Return the torch.Tensor
    raise TypeError(f"Expected a torch.Tensor or a dict of torch.Tensors, got {type(tensor)}")

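# For example (with a and b standing in for tensors), a multi-entry dict is treated as deep
# supervision output while a single-entry dict is unwrapped:
#
#     prepare_loss_arg({"0-128x128": a, "1-64x64": b})  # [a, b]
#     prepare_loss_arg({"0-128x128": a})                # a
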
class nnUNetDataLoaderWrapper(DataLoader):
    def __init__(
        self,
        nnunet_augmenter: SingleThreadedAugmenter | NonDetMultiThreadedAugmenter | MultiThreadedAugmenter,
        nnunet_config: NnunetConfig | str,
        infinite: bool = False,
        set_len: int | None = None,
        ref_image_shape: Sequence | None = None,
    ) -> None:
        """
        Wraps nnunet dataloader classes using the pytorch dataloader to make them pytorch
        compatible. Also handles some unique stuff specific to nnunet such as deep supervision and
        infinite dataloaders. The nnunet dataloaders should only be used for training and
        validation, not final testing.

        Args:
            nnunet_augmenter (SingleThreadedAugmenter | NonDetMultiThreadedAugmenter | MultiThreadedAugmenter):
                The augmenter used by nnunet, which wraps the underlying nnunet dataloader.
            nnunet_config (NnunetConfig | str): The nnunet config. The enum type helps ensure that
                the nnunet config is valid.
            infinite (bool, optional): Whether or not to treat the dataset as infinite. The
                dataloaders sample data with replacement either way. The only difference is that if
                set to False, a ``StopIteration`` is generated after ``num_samples``/``batch_size``
                steps. Defaults to False.
            set_len (int | None): If specified, overrides the dataloader's estimate of its own
                length with the provided value. A ``StopIteration`` will be raised after
                ``set_len`` steps. If not specified, the length is determined by scaling the number
                of samples by the ratio of image size to the network's input patch size.
            ref_image_shape (Sequence | None): The image shape to use when computing the scaling
                factor used in determining the length of the dataloader. Should be representative
                of the median or average image size in the dataset. If not specified, a random
                image is loaded and its shape is used in the calculation of the scaling factor.
        """
        # The augmenter is a wrapper on the nnunet dataloader
        self.nnunet_augmenter = nnunet_augmenter
        if isinstance(self.nnunet_augmenter, SingleThreadedAugmenter):
            self.nnunet_dataloader = self.nnunet_augmenter.data_loader
        else:
            self.nnunet_dataloader = self.nnunet_augmenter.generator

        # Figure out if the dataloader is 2d or 3d
        self.num_spatial_dims = NNUNET_N_SPATIAL_DIMS[NnunetConfig(nnunet_config)]

        # nnUNet dataloaders store their datasets under the _data attribute
        self.dataset: nnUNetDataset = self.nnunet_dataloader._data
        super().__init__(dataset=self.dataset, batch_size=self.nnunet_dataloader.batch_size)

        # nnunet dataloaders are infinite by default so we have to track steps to stop iteration
        self.current_step = 0
        self.infinite = infinite
        self.ref_image_shape = ref_image_shape
        self.set_len = set_len

    def reset(self) -> None:
        # Reset the step counter so that iteration can begin again
        self.current_step = 0

    def __next__(self) -> tuple[torch.Tensor, torch.Tensor | dict[str, torch.Tensor]]:
        if not self.infinite and self.current_step == self.__len__():
            self.reset()
            raise StopIteration  # Raise StopIteration after the epoch has completed
        self.current_step += 1
        batch = next(self.nnunet_augmenter)  # This returns a dictionary
        # Note: When deep supervision is on, the target is a list of ground truth segmentations
        # at various spatial scales/resolutions. nnunet has a wrapper for loss functions to
        # enable deep supervision
        inputs: torch.Tensor = batch["data"]
        targets: torch.Tensor | list[torch.Tensor] = batch["target"]
        if isinstance(targets, list):
            target_dict = convert_deep_supervision_list_to_dict(targets, self.num_spatial_dims)
            return inputs, target_dict
        elif isinstance(targets, torch.Tensor):
            return inputs, targets
        else:
            raise TypeError(
                "Was expecting the target generated by the nnunet dataloader to be a list or a torch.Tensor"
            )

    def __len__(self) -> int:
        """
        nnunetv2 v2.5.1 hardcodes an 'epoch' as 250 steps. We could set the length to
        ``n_samples``/``batch_size``, but this gets complicated because nnunet models operate on
        patches of the input images, and can therefore have batch sizes larger than the dataset;
        we would then have epochs with only 1 step! Here we go through the hassle of computing the
        ratio between the number of voxels in a sample and the number of voxels in a patch, and
        then use that factor to scale ``n_samples``. This is particularly important for training
        2d models on 3d data.
        """
        if self.set_len is not None:
            return self.set_len

        if self.ref_image_shape is None:
            # A sample will have shape (n_channels, x, y, ...)
            sample, _, _ = self.dataset.load_case(self.nnunet_dataloader.indices[0])
            self.ref_image_shape = sample.shape[1:]  # Must remove the channel dimension

        n_image_voxels = np.prod(self.ref_image_shape)
        n_patch_voxels = np.prod(self.nnunet_dataloader.final_patch_size)
        # The scale factor is at least one to prevent shrinking the dataset. We can sometimes
        # have a larger patch size because nnunet will do padding
        scale = max(n_image_voxels / n_patch_voxels, 1)
        # Scale n_samples and then divide by the batch size to get n_steps per epoch
        return round((len(self.dataset) * scale) / self.nnunet_dataloader.batch_size)

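    # Worked example of the scaling above (illustrative numbers): 100 samples of shape
    # 512x512x128 voxels with a 2d patch size of 512x512 give
    # scale = (512 * 512 * 128) / (512 * 512) = 128, so with a batch size of 32 the length is
    # round((100 * 128) / 32) = 400 steps per epoch.
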
    def __iter__(self) -> DataLoader:  # type: ignore
        # mypy gets angry that the return type differs from that of the parent class
        return self

    def shutdown(self) -> None:
        """The multithreaded augmenters used by nnunet need to be shut down gracefully to avoid errors."""
        if isinstance(self.nnunet_augmenter, (NonDetMultiThreadedAugmenter, MultiThreadedAugmenter)):
            self.nnunet_augmenter._finish()
        else:
            del self.nnunet_augmenter

class Module2LossWrapper(_Loss):
    """Converts an ``nn.Module`` subclass to a ``_Loss`` subclass."""

    def __init__(self, loss: nn.Module, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.loss = loss

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Simple delegation to the wrapped module (assumed; the forward method was not captured
        # in the source shown here)
        return self.loss(pred, target)

class StreamToLogger(io.StringIO):
    def __init__(self, logger: Logger, level: LogLevel | int) -> None:
        """
        File-like stream object that redirects writes to a logger. Useful for redirecting stdout
        to a logger.

        Args:
            logger (Logger): The logger to redirect writes to.
            level (LogLevel | int): The log level at which to redirect the writes.
        """
        self.logger = logger
        self.level = level if isinstance(level, int) else level.value
        self.linebuf = ""  # Idk why this is needed. Got this class from stack overflow

    # A minimal write/flush pair following the standard recipe for a logger-backed stream
    # (assumed; these methods were not captured in the source shown here)
    def write(self, buf: str) -> None:
        for line in buf.rstrip().splitlines():
            self.logger.log(self.level, line.rstrip())

    def flush(self) -> None:
        pass

class PolyLRSchedulerWrapper(_LRScheduler):
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        initial_lr: float,
        max_steps: int,
        exponent: float = 0.9,
        steps_per_lr: int = 250,
    ) -> None:
        """
        Learning rate (LR) scheduler with polynomial decay across fixed windows of size
        ``steps_per_lr``.

        Args:
            optimizer (Optimizer): The optimizer to apply the LR scheduler to.
            initial_lr (float): The initial learning rate of the optimizer.
            max_steps (int): The maximum total number of steps across all FL rounds.
            exponent (float): Controls how quickly the LR decreases over time. Higher values lead
                to a more rapid descent. Defaults to 0.9.
            steps_per_lr (int): The number of steps per LR before decaying (i.e. 10 means the LR
                will be constant for 10 steps prior to being decreased to the subsequent value).
                Defaults to 250, as that is the default for nnunet (the LR is decayed once per
                epoch and an epoch is 250 steps).
        """
        self.optimizer = optimizer
        self.initial_lr = initial_lr
        self.max_steps = max_steps
        self.exponent = exponent
        self.steps_per_lr = steps_per_lr
        # Number of windows with constant LR across training
        self.num_windows = ceil(max_steps / self.steps_per_lr)
        self._step_count: int
        super().__init__(optimizer, -1, False)

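    # Worked example of the windowed decay (illustrative numbers): max_steps=1000 and
    # steps_per_lr=250 give ceil(1000 / 250) = 4 windows. During window k the LR is
    # initial_lr * (1 - k / 4) ** exponent, so with initial_lr=0.01 and exponent=0.9 the four
    # windows use LRs of 0.01, ~0.0077, ~0.0054 and ~0.0029.
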
    @no_type_check
    def get_lr(self) -> Sequence[float]:
        """
        Get the current LR of the scheduler.

        Returns:
            Sequence[float]: A uniform sequence of LRs, one for each of the parameter groups in
            the optimizer.
        """
        if self._step_count - 1 == self.max_steps + 1:
            log(
                WARN,
                f"Current LR step of {self._step_count} reached Max Steps of {self.max_steps}. LR will remain fixed.",
            )

        # Subtract 1 from the step count since it starts at 1 (imposed by PyTorch)
        curr_step = min(self._step_count - 1, self.max_steps)
        curr_window = int(curr_step / self.steps_per_lr)
        new_lr = self.initial_lr * (1 - curr_window / self.num_windows) ** self.exponent

        if curr_step % self.steps_per_lr == 0 and curr_step != 0 and curr_step != self.max_steps:
            log(INFO, f"Decaying LR of optimizer to {new_lr} at step {curr_step}")

        return [new_lr] * len(self.optimizer.param_groups)

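# Usage sketch (model and step counts are illustrative): the wrapper is stepped once per
# optimizer update, like any other PyTorch LR scheduler.
#
#     model = nn.Linear(10, 2)
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
#     scheduler = PolyLRSchedulerWrapper(optimizer, initial_lr=0.01, max_steps=1000)
#     for _ in range(1000):
#         optimizer.step()
#         scheduler.step()
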