fl4health.losses.deep_mmd_loss module

class DeepMmdLoss(device, input_size, hidden_size=10, output_size=50, lr=0.001, is_unbiased=True, gaussian_degree=1, optimization_steps=5)[source]

Bases: Module

MMDu(features, len_s, features_org, sigma_q, sigma_phi, epsilon, is_smooth=True, is_var_computed=True)[source]

Compute the value and standard deviation of the deep-kernel MMD using the merged data.

Parameters:
  • features (torch.Tensor) – The output features of the deep network.

  • len_s (int) – The number of samples in the first of the two merged samples.

  • features_org (torch.Tensor) – The original input features of the deep network.

  • sigma_q (torch.Tensor) – Bandwidth of the Gaussian kernel applied to the original input features.

  • sigma_phi (torch.Tensor) – Bandwidth of the Gaussian kernel applied to the deep network features.

  • epsilon (torch.Tensor) – Mixing weight between the deep-feature kernel and the input-space kernel.

  • is_smooth (bool, optional) – Whether to use the smooth version of the MMD. Defaults to True.

  • is_var_computed (bool, optional) – Whether to compute the variance of the MMD. Defaults to True.

Returns:

The value of the MMD and, if is_var_computed is True, the variance of the MMD; otherwise None for the variance.

Return type:

tuple[torch.Tensor, torch.Tensor | None]
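
For reference, the smooth deep kernel described in Learning Deep Kernels for Non-Parametric Two-Sample Tests, to which sigma_q, sigma_phi and epsilon correspond, has roughly the form below. This is a sketch of the paper's kernel; the exact bandwidth parameterization and the role of gaussian_degree (which raises the scaled squared distances to a power) may differ slightly in this implementation.

    k(x, y) = \left[(1 - \epsilon)\exp\left(-\frac{\lVert \phi(x) - \phi(y) \rVert^{2}}{\sigma_{\phi}}\right) + \epsilon\right]\exp\left(-\frac{\lVert x - y \rVert^{2}}{\sigma_{q}}\right)

Here, \phi denotes the deep network (ModelLatentF) applied to the inputs.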

__init__(device, input_size, hidden_size=10, output_size=50, lr=0.001, is_unbiased=True, gaussian_degree=1, optimization_steps=5)[source]

Compute the Deep MMD (Maximum Mean Discrepancy) loss, as proposed in the paper Learning Deep Kernels for Non-Parametric Two-Sample Tests. This loss function uses a kernel-based approach to assess whether two samples are drawn from the same distribution. By minimizing this loss, we can learn a deep kernel that reduces the MMD distance between two distributions, ensuring that the input feature representations are aligned. This implementation is inspired by the original code from the paper: https://github.com/fengliu90/DK-for-TST.

Parameters:
  • device (torch.device) – Device onto which tensors should be moved.

  • input_size (int) – The size of the input feature representations fed to the deep network used as the deep kernel to compute the MMD loss.

  • hidden_size (int, optional) – The hidden size of the deep network used as the deep kernel to compute the MMD loss. Defaults to 10.

  • output_size (int, optional) – The output size of the deep network used as the deep kernel to compute the MMD loss. Defaults to 50.

  • lr (float, optional) – Learning rate for training the Deep Kernel. Defaults to 0.001.

  • is_unbiased (bool, optional) – Whether to use the unbiased estimator for the MMD loss. Defaults to True.

  • gaussian_degree (int, optional) – The degree of the generalized Gaussian kernel. Defaults to 1.

  • optimization_steps (int, optional) – The number of optimization steps to train the Deep Kernel in each forward pass. Defaults to 5.
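
A minimal usage sketch, assuming two batches of feature representations of shape (batch_size, input_size); the batch size and feature dimension below are illustrative:

    import torch

    from fl4health.losses.deep_mmd_loss import DeepMmdLoss

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Deep MMD loss over 64-dimensional feature representations (illustrative sizes).
    deep_mmd = DeepMmdLoss(device=device, input_size=64, hidden_size=10, output_size=50)

    source_features = torch.randn(32, 64, device=device)
    target_features = torch.randn(32, 64, device=device)

    # forward() first takes a few optimization steps on the deep kernel and then
    # returns the Deep MMD loss between the two feature batches.
    loss = deep_mmd(source_features, target_features)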

compute_kernel(X, Y)[source]

Compute the Deep MMD loss between the input tensors X and Y using the learned deep kernel.

Parameters:
  • X (torch.Tensor) – The input tensor X.

  • Y (torch.Tensor) – The input tensor Y.

Returns:

The value of the Deep MMD loss.

Return type:

torch.Tensor

forward(Xs, Xt)[source]

Forward pass of the Deep MMD loss, which first trains the deep kernel for the configured number of optimization steps and then computes the MMD loss between the source and target inputs.

Parameters:
  • Xs (torch.Tensor) – The source input tensor.

  • Xt (torch.Tensor) – The target input tensor.

Returns:

The value of the Deep MMD loss.

Return type:

torch.Tensor
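
As an illustration, this loss can be added to a task loss to align intermediate feature representations. The sketch below is hypothetical: the model (assumed to return both features and predictions), criterion, optimizer, reference features, and weighting are stand-ins, not part of this module.

    import torch
    import torch.nn as nn

    def training_step_with_mmd_alignment(
        model: nn.Module,              # assumed to return (features, predictions)
        deep_mmd: nn.Module,           # a DeepMmdLoss instance
        criterion: nn.Module,
        optimizer: torch.optim.Optimizer,
        inputs: torch.Tensor,
        targets: torch.Tensor,
        reference_features: torch.Tensor,
        mmd_weight: float = 1.0,
    ) -> torch.Tensor:
        # Combine the task loss with the Deep MMD feature-alignment loss.
        features, preds = model(inputs)
        task_loss = criterion(preds, targets)
        alignment_loss = deep_mmd(features, reference_features)
        total_loss = task_loss + mmd_weight * alignment_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        return total_loss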

h1_mean_var_gram(k_x, k_y, k_xy, is_var_computed)[source]

Compute the value and standard deviation of the MMD using the kernel matrices.

Parameters:
  • k_x (torch.Tensor) – The kernel matrix of x.

  • k_y (torch.Tensor) – The kernel matrix of y.

  • k_xy (torch.Tensor) – The kernel matrix of x and y.

  • is_var_computed (bool) – Whether to compute the variance of the MMD.

Returns:

The value of the MMD and, if is_var_computed is True, the variance of the MMD; otherwise None for the variance.

Return type:

tuple[torch.Tensor, torch.Tensor | None]
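
For intuition, the unbiased MMD estimate can be read off the Gram matrices as sketched below. This is a simplified sketch that omits the variance computation; the actual method also returns the variance estimate when is_var_computed is True.

    import torch

    def unbiased_mmd_from_gram(k_x: torch.Tensor, k_y: torch.Tensor, k_xy: torch.Tensor) -> torch.Tensor:
        # Average the off-diagonal entries of k_x and k_y and subtract twice the
        # mean of the cross-kernel matrix k_xy (the standard unbiased estimator).
        n, m = k_x.shape[0], k_y.shape[0]
        xx_term = (k_x.sum() - k_x.diagonal().sum()) / (n * (n - 1))
        yy_term = (k_y.sum() - k_y.diagonal().sum()) / (m * (m - 1))
        xy_term = k_xy.sum() / (n * m)
        return xx_term + yy_term - 2.0 * xy_term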

pairwise_distance_squared(X, Y)[source]

Compute the pairwise squared distances between X and Y.

Parameters:
  • X (torch.Tensor) – The input tensor X.

  • Y (torch.Tensor) – The input tensor Y.

Returns:

The matrix of pairwise squared distances between X and Y.

Return type:

torch.Tensor
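
Such a pairwise squared distance matrix can be computed with the standard expansion ||x - y||^2 = ||x||^2 - 2<x, y> + ||y||^2; the sketch below illustrates the computation and is not necessarily the exact implementation.

    import torch

    def pairwise_squared_distances(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
        # Maps (n, d) and (m, d) inputs to an (n, m) matrix of squared Euclidean distances.
        x_norm = (X**2).sum(dim=1, keepdim=True)      # (n, 1)
        y_norm = (Y**2).sum(dim=1, keepdim=True).t()  # (1, m)
        distances = x_norm - 2.0 * X @ Y.t() + y_norm
        return distances.clamp(min=0.0)               # guard against small negative values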

train_kernel(X, Y)[source]

Train the Deep MMD kernel.

Parameters:
  • X (torch.Tensor) – The input tensor X.

  • Y (torch.Tensor) – The input tensor Y.

Return type:

None
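
The deep kernel is trained by maximizing the test-power criterion from the paper, roughly the MMD estimate divided by the square root of its variance estimate. Below is a hedged sketch of a single optimization step, where mmd_value and mmd_variance stand in for the outputs of MMDu and kernel_optimizer is a hypothetical optimizer over the deep network and kernel parameters.

    import torch

    def kernel_training_step(
        mmd_value: torch.Tensor,
        mmd_variance: torch.Tensor,
        kernel_optimizer: torch.optim.Optimizer,
    ) -> None:
        # Maximize MMD / sqrt(variance) by minimizing its negative; the small
        # constant guarding the square root is an assumption of this sketch.
        objective = -mmd_value / torch.sqrt(mmd_variance + 1e-8)
        kernel_optimizer.zero_grad()
        objective.backward()
        kernel_optimizer.step()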

class ModelLatentF(x_in_dim, hidden_dim, x_out_dim)[source]

Bases: Module

__init__(x_in_dim, hidden_dim, x_out_dim)[source]

Deep network for learning the deep kernel over features.

Parameters:
  • x_in_dim (int) – The input dimension of the deep network.

  • hidden_dim (int) – The hidden dimension of the deep network.

  • x_out_dim (int) – The output dimension of the deep network.
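
Conceptually, this is a small fully connected featurizer mapping x_in_dim inputs to x_out_dim features through hidden layers. The sketch below conveys the shape of such a network; the exact depth and the Softplus activation are assumptions, and the sizes are illustrative.

    import torch
    import torch.nn as nn

    x_in_dim, hidden_dim, x_out_dim = 64, 10, 50
    featurizer = nn.Sequential(
        nn.Linear(x_in_dim, hidden_dim),
        nn.Softplus(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.Softplus(),
        nn.Linear(hidden_dim, x_out_dim),
    )
    features = featurizer(torch.randn(32, x_in_dim))  # shape (32, x_out_dim)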

forward(input)[source]

Forward pass of the deep network.

Parameters:

input (torch.Tensor) – The input tensor to the deep network.

Returns:

The output tensor of the deep network.

Return type:

torch.Tensor