fl4health.losses.mkmmd_loss module

class MkMmdLoss(device, gammas=None, betas=None, minimize_type_two_error=True, normalize_features=False, layer_name=None, perform_linear_approximation=False)

Bases: Module

__init__(device, gammas=None, betas=None, minimize_type_two_error=True, normalize_features=False, layer_name=None, perform_linear_approximation=False)

Compute the multi-kernel maximum mean discrepancy (MK-MMD) between the source and target domains. Also allows the kernel coefficients, betas, to be optimized.
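For reference, the quantity being estimated can be written with the standard MK-MMD formulation of Gretton et al. (2012). The RBF form below, with each gamma acting as a length-scale that divides the squared distance, is an assumption about how this module parameterizes its kernels. The per-kernel expectations of h_u are what methods such as compute_hat_d_per_kernel appear to estimate, judging by their names.

```latex
k(a, b) = \sum_{u=1}^{m} \beta_u \, k_u(a, b),
\qquad
k_u(a, b) = \exp\!\left( -\frac{\lVert a - b \rVert^2}{\gamma_u} \right)

h_u(x, x', y, y') = k_u(x, x') + k_u(y, y') - k_u(x, y') - k_u(x', y)

\operatorname{MK\text{-}MMD}^2 = \sum_{u=1}^{m} \beta_u \, \mathbb{E}\left[ h_u(x, x', y, y') \right],
\qquad x, x' \sim \text{source}, \; y, y' \sim \text{target}
```

Because MMD^2 is linear in the kernel, the MK-MMD for the combined kernel k is exactly the beta-weighted sum of the single-kernel MMD^2 values.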

Parameters:
  • device (torch.device) – Device onto which tensors should be moved.

  • gammas (torch.Tensor | None, optional) – The length-scales of the RBF kernels used to compute the MK-MMD distances. The number of gammas sets the number of kernels used in the measurement. If None, a default of 19 kernels is used. Defaults to None.

  • betas (torch.Tensor | None, optional) – The linear coefficients applied to the basis of kernels to compute the MK-MMD measure. If not provided, a random, unit-length default is constructed. These can be optimized with the methods of this class (see optimize_betas). Defaults to None.

  • minimize_type_two_error (bool | None, optional) – Whether optimizing the betas should minimize the type II error (True) or maximize it (False). Minimizing coincides with trying to minimize the feature distance; maximizing coincides with trying to maximize the feature distance. Defaults to True.

  • normalize_features (bool | None, optional) – Whether to normalize the features to have unit length before computing the MK-MMD and optimizing betas. Defaults to False.

  • layer_name (str | None, optional) – The name of the layer to extract features from. Defaults to None.

  • perform_linear_approximation (bool | None, optional) – Whether to use linear approximations for the estimates of the mean and covariance of the kernel values. Experimentally, we have found that the linear approximations substantially reduce the statistical power of MK-MMD. Defaults to False.
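The default gammas and betas described above can be sketched in NumPy as follows; the module itself works with torch tensors. The specific default length-scales (2^p for p from -3.5 to 1.0 in steps of 0.25, which yields the documented 19 kernels) and the reading of "unit-length" as nonnegative weights summing to one are assumptions, not taken from this page.

```python
import numpy as np


def default_gammas() -> np.ndarray:
    # Assumed default: length-scales 2**p for p in [-3.5, 1.0] with step 0.25.
    # np.arange excludes the stop value, so stopping at 1.25 makes 1.0 the last
    # power, giving 19 kernels in total.
    powers = np.arange(-3.5, 1.25, 0.25)
    return np.power(2.0, powers)


def default_betas(num_kernels: int, seed: int = 0) -> np.ndarray:
    # Assumed default: one random nonnegative coefficient per kernel,
    # normalized to sum to one; shape (num_kernels, 1).
    rng = np.random.default_rng(seed)
    coefficients = rng.random((num_kernels, 1))
    return coefficients / coefficients.sum()


gammas = default_gammas()
betas = default_betas(len(gammas))
```

Keeping the betas nonnegative and normalized means the combined kernel stays a convex combination of the individual RBF kernels.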

beta_with_extreme_kernel_base_values(hat_d_per_kernel, hat_Q_k, minimize_type_two_error=True)
Return type:

Tensor

compute_all_h_u_all_samples(X, Y)
Return type:

Tensor

compute_all_h_u_from_inner_products(inner_product_all_samples)
Return type:

Tensor

compute_all_h_u_from_inner_products_linear(inner_product_quadruples)
Return type:

Tensor

compute_all_h_u_linear(X, Y)
Return type:

Tensor

compute_euclidean_inner_products(X, Y)
Return type:

Tensor

compute_euclidean_inner_products_linear(v_i_quadruples)
Return type:

Tensor

compute_h_u_from_inner_products(inner_products, gamma)
Return type:

Tensor

compute_h_u_from_inner_products_linear(inner_products, gamma)
Return type:

Tensor

compute_hat_Q_k(all_h_u_per_sample, hat_d_per_kernel)
Return type:

Tensor

compute_hat_Q_k_linear(all_h_u_per_v_i)
Return type:

Tensor

compute_hat_d_per_kernel(all_h_u_per_sample)
Return type:

Tensor

compute_mkmmd(X, Y, beta)
Return type:

Tensor

compute_vertices(hat_d_per_kernel)
Return type:

Tensor

construct_quadruples(X, Y)

This function assumes that X and Y have the same shape, (n_samples, n_features). It constructs the quadruples v_i = [x_{2i-1}, x_{2i}, y_{2i-1}, y_{2i}], forming a tensor of shape (n_samples // 2, 4, n_features). If n_samples is odd, the final sample is dropped.

Return type:

Tensor
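The quadruple construction above, and one way it can feed linear-time estimates of the per-kernel mean and covariance, can be sketched in NumPy as follows. The helper linear_h_per_kernel is hypothetical; it assumes RBF kernels of the form exp(-||a - b||^2 / gamma) and the linear-time statistic of Gretton et al. (2012). A plain sample covariance stands in for the module's own covariance estimate, which may differ.

```python
import numpy as np


def construct_quadruples(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    # X and Y are assumed to share the shape (n_samples, n_features).
    n_samples, n_features = X.shape
    n_pairs = n_samples // 2  # an odd trailing sample is dropped
    x_pairs = X[: 2 * n_pairs].reshape(n_pairs, 2, n_features)  # [x_{2i-1}, x_{2i}]
    y_pairs = Y[: 2 * n_pairs].reshape(n_pairs, 2, n_features)  # [y_{2i-1}, y_{2i}]
    # v_i = [x_{2i-1}, x_{2i}, y_{2i-1}, y_{2i}]: shape (n_samples // 2, 4, n_features)
    return np.concatenate([x_pairs, y_pairs], axis=1)


def linear_h_per_kernel(quadruples: np.ndarray, gammas: np.ndarray) -> np.ndarray:
    # Hypothetical helper: h(v_i) = k(x, x') + k(y, y') - k(x, y') - k(x', y) for
    # every kernel, assuming RBF kernels k(a, b) = exp(-||a - b||^2 / gamma).
    x, x_prime, y, y_prime = (quadruples[:, j] for j in range(4))

    def kernel(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        sq_dist = np.sum((a - b) ** 2, axis=1, keepdims=True)  # (n_pairs, 1)
        return np.exp(-sq_dist / gammas[None, :])  # (n_pairs, n_kernels)

    return kernel(x, x_prime) + kernel(y, y_prime) - kernel(x, y_prime) - kernel(x_prime, y)


rng = np.random.default_rng(0)
X = rng.normal(size=(7, 3))
Y = rng.normal(loc=1.0, size=(7, 3))
quadruples = construct_quadruples(X, Y)  # shape (3, 4, 3); X[6] and Y[6] are dropped
gammas = np.array([0.5, 1.0, 2.0])
h = linear_h_per_kernel(quadruples, gammas)  # one row per quadruple, one column per kernel
hat_d = h.mean(axis=0)  # linear-time estimate of the per-kernel MMD^2
hat_Q = np.cov(h, rowvar=False)  # sample covariance of h across kernels
```

Each quadruple is used once, so the cost grows linearly with n_samples, at the price of the reduced statistical power noted for perform_linear_approximation.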

form_and_solve_qp(hat_d_per_kernel, regularized_Q_k)
Return type:

Tensor

form_h_u_delta_w_i(all_h_u_per_v_i)
Return type:

Tensor

form_kernel_samples_minus_expectation(all_h_u_per_sample, hat_d_per_kernel)
Return type:

Tensor

forward(Xs, Xt)

Compute the multi-kernel maximum mean discrepancy (MK-MMD) between the source and target domains.

Parameters:
  • Xs (torch.Tensor) – Source domain data, shape (n_samples, n_features).

  • Xt (torch.Tensor) – Target domain data, shape (n_samples, n_features).

Returns:

MK-MMD value

Return type:

torch.Tensor
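A minimal NumPy sketch of the full (non-linear-approximation) quantity that forward returns, assuming the standard unbiased quadratic-time MK-MMD estimator of Gretton et al. (2012) and RBF kernels exp(-||a - b||^2 / gamma). The function names are hypothetical, and the module's own path (through methods such as compute_all_h_u_all_samples and compute_mkmmd) may differ in details.

```python
import numpy as np


def squared_distances(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    # Pairwise squared Euclidean distances ||a_i - b_j||^2, shape (len(A), len(B)).
    return np.sum((A[:, None, :] - B[None, :, :]) ** 2, axis=-1)


def mkmmd_unbiased(
    Xs: np.ndarray,
    Xt: np.ndarray,
    gammas: np.ndarray,
    betas: np.ndarray,
    normalize_features: bool = False,
) -> float:
    # Sketch of sum_u beta_u * (unbiased MMD^2 under kernel u), assuming
    # RBF kernels k_u(a, b) = exp(-||a - b||^2 / gamma_u).
    if normalize_features:
        # Scale every feature vector (row) to unit Euclidean length.
        Xs = Xs / np.linalg.norm(Xs, axis=1, keepdims=True)
        Xt = Xt / np.linalg.norm(Xt, axis=1, keepdims=True)
    n, m = len(Xs), len(Xt)
    d_ss = squared_distances(Xs, Xs)
    d_tt = squared_distances(Xt, Xt)
    d_st = squared_distances(Xs, Xt)
    total = 0.0
    for gamma, beta in zip(gammas, np.ravel(betas)):
        k_ss, k_tt, k_st = (np.exp(-d / gamma) for d in (d_ss, d_tt, d_st))
        # Unbiased within-domain terms exclude the i == j diagonal.
        term_ss = (k_ss.sum() - np.trace(k_ss)) / (n * (n - 1))
        term_tt = (k_tt.sum() - np.trace(k_tt)) / (m * (m - 1))
        total += beta * (term_ss + term_tt - 2.0 * k_st.mean())
    return float(total)


rng = np.random.default_rng(0)
gammas = np.power(2.0, np.arange(-3.5, 1.25, 0.25))
betas = np.full(len(gammas), 1.0 / len(gammas))
Xs = rng.normal(size=(50, 4))
same = mkmmd_unbiased(Xs, rng.normal(size=(50, 4)), gammas, betas)
shifted = mkmmd_unbiased(Xs, rng.normal(loc=2.0, size=(50, 4)), gammas, betas)
```

Because the within-domain terms drop the diagonal, the estimate is unbiased but can be slightly negative when the two domains are drawn from the same distribution; a mean shift between domains drives it clearly above zero.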

get_best_vertex_for_objective_function(hat_d_per_kernel, hat_Q_k)
Return type:

Tensor

normalize(X)
Return type:

Tensor

optimize_betas(X, Y, lambda_m=1e-05)
Return type:

Tensor
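The beta optimization can be summarized with the standard MK-MMD kernel-selection problem of Gretton et al. (2012), shown here for the minimize_type_two_error=True case, where hat_d is the vector of per-kernel MMD^2 estimates and hat_Q_k their covariance. Treating lambda_m as the ridge term added to hat_Q_k (suggested by the regularized_Q_k argument of form_and_solve_qp) is an assumption; this page does not spell out the module's exact formulation or how the maximization case is handled.

```latex
\beta^{\ast} = \arg\min_{\beta \ge 0} \; \beta^{\top}\left(\hat{Q}_k + \lambda_m I\right)\beta
\quad \text{subject to} \quad \hat{d}^{\top}\beta = 1,
\qquad
\beta = \frac{\beta^{\ast}}{\sum_{u} \beta^{\ast}_u}
```

The final rescaling to unit sum does not change which kernel combination is selected, because the test-power ratio \hat{d}^{\top}\beta / \sqrt{\beta^{\top}\hat{Q}_k\beta} is invariant to rescaling beta.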