fl4health.losses.deep_mmd_loss module¶
- class DeepMmdLoss(device, input_size, hidden_size=10, output_size=50, lr=0.001, is_unbiased=True, gaussian_degree=1, optimization_steps=5)[source]¶
Bases:
Module
- MMDu(features, len_s, features_org, sigma_q, sigma_phi, epsilon, is_smooth=True, is_var_computed=True)[source]¶
Compute the value and the standard deviation of the deep-kernel MMD using the merged data.
- Parameters:
features (torch.Tensor) – The output features of the deep network.
len_s (int) – The number of samples in the first of the two samples making up the merged data, used to split the features into the two groups.
features_org (torch.Tensor) – The original input features of the deep network.
sigma_q (torch.Tensor) – The sigma_q parameter: the bandwidth of the Gaussian kernel applied to the original input features (q in the kernel form sketched below).
sigma_phi (torch.Tensor) – The sigma_phi parameter: the bandwidth of the kernel applied to the deep features.
epsilon (torch.Tensor) – The epsilon parameter: the coefficient mixing the deep kernel with the input-space Gaussian kernel.
is_smooth (bool, optional) – Whether to use the smooth version of the MMD. Defaults to True.
is_var_computed (bool, optional) – Whether to compute the variance of the MMD. Defaults to True.
- Returns:
The value of the MMD and, if is_var_computed is True, the variance of the MMD (None otherwise).
- Return type:
tuple[torch.Tensor, torch.Tensor | None]
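For orientation, the deep kernel proposed in Liu et al. (whose parameters this method mirrors) has the following form; the notation is paraphrased from the paper rather than taken from this codebase:

```latex
k_\omega(x, y) = \left[ (1 - \epsilon)\,\kappa\big(\phi_\omega(x), \phi_\omega(y)\big) + \epsilon \right] q(x, y)
```

Here phi_omega is the deep network producing the output features, kappa is a Gaussian kernel on the deep features with bandwidth sigma_phi, q is a Gaussian kernel on the original input features with bandwidth sigma_q, and epsilon mixes the two so that the overall kernel remains characteristic.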
- __init__(device, input_size, hidden_size=10, output_size=50, lr=0.001, is_unbiased=True, gaussian_degree=1, optimization_steps=5)[source]¶
Compute the Deep MMD (Maximum Mean Discrepancy) loss, as proposed in the paper Learning Deep Kernels for Non-Parametric Two-Sample Tests. This loss function uses a kernel-based approach to assess whether two samples are drawn from the same distribution. By minimizing this loss, we can learn a deep kernel that reduces the MMD distance between two distributions, ensuring that the input feature representations are aligned. This implementation is inspired by the original code from the paper: https://github.com/fengliu90/DK-for-TST. A brief usage sketch follows the parameter list below.
- Parameters:
device (torch.device) – The device onto which tensors should be moved.
input_size (int) – The size of the input feature representations fed to the deep network that serves as the deep kernel for computing the MMD loss.
hidden_size (int, optional) – The hidden layer size of the deep network serving as the deep kernel. Defaults to 10.
output_size (int, optional) – The output size of the deep network serving as the deep kernel. Defaults to 50.
lr (float, optional) – Learning rate for training the Deep Kernel. Defaults to 0.001.
is_unbiased (bool, optional) – Whether to use the unbiased estimator for the MMD loss. Defaults to True.
gaussian_degree (int, optional) – The degree of the generalized Gaussian kernel. Defaults to 1.
optimization_steps (int, optional) – The number of optimization steps to train the Deep Kernel in each forward pass. Defaults to 5.
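For orientation, a minimal usage sketch; the tensor shapes, feature sizes, and batch contents below are illustrative assumptions, not requirements fixed by the class:

```python
import torch

from fl4health.losses.deep_mmd_loss import DeepMmdLoss

# Illustrative setup: two batches of 32 feature vectors of dimension 64.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
deep_mmd_loss = DeepMmdLoss(device=device, input_size=64)

source_features = torch.randn(32, 64, device=device)
target_features = torch.randn(32, 64, device=device)

# Calling the module invokes forward(Xs, Xt): the deep kernel is first
# trained for optimization_steps steps, then the MMD loss is returned.
loss = deep_mmd_loss(source_features, target_features)
```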
- compute_kernel(X, Y)[source]¶
Compute the Deep MMD loss between X and Y using the current deep kernel.
- Parameters:
X (torch.Tensor) – The input tensor X.
Y (torch.Tensor) – The input tensor Y.
- Returns:
The value of the Deep MMD loss.
- Return type:
torch.Tensor
- forward(Xs, Xt)[source]¶
Forward pass of the Deep MMD loss: the deep kernel is first trained for the configured number of optimization steps, and the MMD loss between the source and target features is then computed. A training-step sketch follows the return type below.
- Parameters:
Xs (torch.Tensor) – The source input tensor.
Xt (torch.Tensor) – The target input tensor.
- Returns:
The value of the Deep MMD loss.
- Return type:
torch.Tensor
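As a rough sketch of where this forward pass typically sits in a training step; the encoder, optimizer, and mmd_weight below are hypothetical stand-ins, not part of this module:

```python
import torch
import torch.nn as nn

from fl4health.losses.deep_mmd_loss import DeepMmdLoss

# Hypothetical encoder whose 64-dimensional outputs we want to align
# across a source and a target domain (all sizes are illustrative).
device = torch.device("cpu")
encoder = nn.Linear(128, 64).to(device)
deep_mmd_loss = DeepMmdLoss(device=device, input_size=64)
optimizer = torch.optim.SGD(encoder.parameters(), lr=0.01)
mmd_weight = 0.1  # assumed weighting hyperparameter, not part of this module

xs = torch.randn(32, 128, device=device)  # source batch (illustrative)
xt = torch.randn(32, 128, device=device)  # target batch (illustrative)

optimizer.zero_grad()
# forward() trains the deep kernel, then returns the Deep MMD loss,
# which can be weighted and minimized alongside a task loss.
alignment_loss = deep_mmd_loss(encoder(xs), encoder(xt))
(mmd_weight * alignment_loss).backward()
optimizer.step()
```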
- h1_mean_var_gram(k_x, k_y, k_xy, is_var_computed)[source]¶
Compute the value and the standard deviation of the MMD from the kernel matrices; the underlying estimator is sketched below.
- Parameters:
k_x (torch.Tensor) – The kernel matrix of x.
k_y (torch.Tensor) – The kernel matrix of y.
k_xy (torch.Tensor) – The cross kernel matrix between x and y.
is_var_computed (bool) – Whether to compute the variance of the MMD.
- Returns:
The value of the MMD and, if is_var_computed is True, the variance of the MMD (None otherwise).
- Return type:
tuple[torch.Tensor, torch.Tensor | None]
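For reference, with equal sample sizes n, the standard unbiased estimate of the squared MMD built from these three kernel matrices is the U-statistic below; this is the textbook form, and any implementation-specific normalization is an assumption:

```latex
\widehat{\mathrm{MMD}}_u^2
  = \frac{1}{n(n-1)} \sum_{i \neq j}
    \Big[ (k_x)_{ij} + (k_y)_{ij} - (k_{xy})_{ij} - (k_{xy})_{ji} \Big]
```

When is_var_computed is True, the variance is estimated from these same per-pair summand terms (the h-statistics).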