import numpy as np
import torch
class ModelLatentF(torch.nn.Module):
def __init__(self, x_in_dim: int, hidden_dim: int, x_out_dim: int):
"""
Deep network for learning the deep kernel over features.
Args:
x_in_dim (int): The input dimension of the deep network.
hidden_dim (int): The hidden dimension of the deep network.
x_out_dim (int): The output dimension of the deep network.
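Example (illustrative; the dimensions below are arbitrary and not prescribed by this module):
>>> net = ModelLatentF(x_in_dim=32, hidden_dim=10, x_out_dim=50)
>>> features = net(torch.randn(8, 32))  # output has shape (8, 50)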
"""
super().__init__()
self.latent = torch.nn.Sequential(
torch.nn.Linear(x_in_dim, hidden_dim, bias=True),
torch.nn.Softplus(),
torch.nn.Linear(hidden_dim, hidden_dim, bias=True),
torch.nn.Softplus(),
torch.nn.Linear(hidden_dim, hidden_dim, bias=True),
torch.nn.Softplus(),
torch.nn.Linear(hidden_dim, x_out_dim, bias=True),
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the deep network.
Args:
input (torch.Tensor): The input tensor to the deep network.
Returns:
torch.Tensor: The output tensor of the deep network.
"""
feature_latent_map = self.latent(input)
return feature_latent_map
class DeepMmdLoss(torch.nn.Module):
def __init__(
self,
device: torch.device,
input_size: int,
hidden_size: int = 10,
output_size: int = 50,
lr: float = 0.001,
is_unbiased: bool = True,
gaussian_degree: int = 1,
optimization_steps: int = 5,
) -> None:
"""
Compute the Deep MMD (Maximum Mean Discrepancy) loss, as proposed in the paper Learning Deep Kernels for
Non-Parametric Two-Sample Tests. This loss function uses a kernel-based approach to assess whether two
samples are drawn from the same distribution. By minimizing this loss, we can learn a deep kernel that
reduces the MMD distance between two distributions, ensuring that the input feature representations are
aligned. This implementation is inspired by the original code from the paper:
https://github.com/fengliu90/DK-for-TST.
Args:
device (torch.device): Device onto which tensors should be moved.
input_size (int): The dimension of the input feature representations fed to the deep network that
serves as the deep kernel for computing the MMD loss.
hidden_size (int, optional): The hidden dimension of the deep network serving as the deep kernel.
Defaults to 10.
output_size (int, optional): The output dimension of the deep network serving as the deep kernel.
Defaults to 50.
lr (float, optional): Learning rate for training the Deep Kernel. Defaults to 0.001.
is_unbiased (bool, optional): Whether to use the unbiased estimator for the MMD loss. Defaults to True.
gaussian_degree (int, optional): The degree of the generalized Gaussian kernel. Defaults to 1.
optimization_steps (int, optional): The number of optimization steps to train the Deep Kernel in each
forward pass. Defaults to 5.
"""
super().__init__()
self.device = device
self.lr = lr
self.is_unbiased = is_unbiased
self.gaussian_degree = gaussian_degree  # degree of the generalized Gaussian kernel (generalized if gaussian_degree > 1)
self.optimization_steps = optimization_steps
# Initialize the model
self.featurizer = ModelLatentF(input_size, hidden_size, output_size).to(self.device)
# Set the model to evaluation mode by default
self.featurizer.eval()
# Initialize parameters
self.epsilon_opt: torch.Tensor = torch.log(torch.from_numpy(np.random.rand(1) * 10 ** (-10)).to(self.device))
self.epsilon_opt.requires_grad = False
self.sigma_q_opt: torch.Tensor = torch.sqrt(torch.tensor(2 * 32 * 32, dtype=torch.float).to(self.device))
self.sigma_q_opt.requires_grad = False
self.sigma_phi_opt: torch.Tensor = torch.sqrt(torch.tensor(0.005, dtype=torch.float).to(self.device))
self.sigma_phi_opt.requires_grad = False
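# Note: with these initial values, epsilon = exp(epsilon_opt) / (1 + exp(epsilon_opt)) starts close to zero,
# while sigma_q and sigma_phi start at 2 * 32 * 32 and 0.005 respectively once the parameters are squared
# before kernel evaluation.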
# Initialize optimizers
self.optimizer_F = torch.optim.AdamW(
list(self.featurizer.parameters()) + [self.epsilon_opt] + [self.sigma_q_opt] + [self.sigma_phi_opt],
lr=self.lr,
)
# Default to evaluation mode; set self.training = True when the deep kernel should be trained in forward
self.training = False
def pairwise_distance_squared(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
"""
Compute the pairwise squared Euclidean distance between the rows of X and the rows of Y.
Args:
X (torch.Tensor): The input tensor X.
Y (torch.Tensor): The input tensor Y.
Returns:
torch.Tensor: The matrix of pairwise squared distances between X and Y.
"""
x_norm = (X**2).sum(1).view(-1, 1)
y_norm = (Y**2).sum(1).view(1, -1)
paired_distance = x_norm + y_norm - 2.0 * torch.mm(X, torch.transpose(Y, 0, 1))
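# Clamp small negative entries that can arise from floating point error.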
paired_distance[paired_distance < 0] = 0
return paired_distance
def h1_mean_var_gram(
self,
k_x: torch.Tensor,
k_y: torch.Tensor,
k_xy: torch.Tensor,
is_var_computed: bool,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Compute the MMD^2 estimate and, optionally, its variance from the kernel (Gram) matrices.
Args:
k_x (torch.Tensor): The kernel matrix of x.
k_y (torch.Tensor): The kernel matrix of y.
k_xy (torch.Tensor): The kernel matrix of x and y.
is_var_computed (bool): Whether to compute the variance of the MMD.
Returns:
tuple[torch.Tensor, torch.Tensor | None]: The MMD^2 estimate and its variance estimate, or None for
the variance if it is not requested.
"""
nx = k_x.shape[0]
ny = k_y.shape[0]
if self.is_unbiased:
# compute the unbiased MMD estimator (\hat{\text{MMD}}_u^2) defined in Eq. (2) of the paper
xx = torch.div((torch.sum(k_x) - torch.sum(torch.diag(k_x))), (nx * (nx - 1)))
yy = torch.div((torch.sum(k_y) - torch.sum(torch.diag(k_y))), (ny * (ny - 1)))
xy = torch.div((torch.sum(k_xy) - torch.sum(torch.diag(k_xy))), (nx * (ny - 1)))
else:
# compute the biased MMD estimator (\hat{\text{MMD}}_b^2) defined below Equation (2) of the paper
xx = torch.div((torch.sum(k_x)), (nx * nx))
yy = torch.div((torch.sum(k_y)), (ny * ny))
xy = torch.div((torch.sum(k_xy)), (nx * ny))
mmd2 = xx - 2 * xy + yy
if not is_var_computed:
return mmd2, None
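# Entry (i, j) of h_ij is k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(x_j, y_i), the H matrix used in the
# variance estimate of the unbiased MMD statistic.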
h_ij = k_x + k_y - k_xy - k_xy.transpose(0, 1)
# Compute the variance estimate of MMD defined in Equation (5) of the paper
v1 = (4.0 / ny**3) * (torch.dot(h_ij.sum(1), h_ij.sum(1)))
v2 = (4.0 / nx**4) * (h_ij.sum() ** 2)
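# A small constant (1e-8) is added for numerical stability, e.g. before the square root taken in train_kernel.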
variance_estimate = v1 - v2 + (10 ** (-8))
return mmd2, variance_estimate
def MMDu(
self,
features: torch.Tensor,
len_s: int,
features_org: torch.Tensor,
sigma_q: torch.Tensor,
sigma_phi: torch.Tensor,
epsilon: torch.Tensor,
is_smooth: bool = True,
is_var_computed: bool = True,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Compute the deep-kernel MMD^2 estimate and, optionally, its variance from the merged samples.
Args:
features (torch.Tensor): The output features of the deep network.
len_s (int): The number of points in the first sample, used to split the merged features into the two samples.
features_org (torch.Tensor): The original input features of the deep network.
sigma_q (torch.Tensor): The bandwidth of the Gaussian kernel applied to the original input features.
sigma_phi (torch.Tensor): The bandwidth of the kernel applied to the deep-network features.
epsilon (torch.Tensor): The mixing parameter in (0, 1) between the deep kernel and the kernel on the
original input features (Equation (1) of the paper).
is_smooth (bool, optional): Whether to use the smooth version of the MMD.
Defaults to True.
is_var_computed (bool, optional): Whether to compute the variance of the MMD.
Defaults to True.
Returns:
tuple[torch.Tensor, torch.Tensor | None]: The MMD^2 estimate and its variance estimate, or None for
the variance if it is not requested.
"""
x = features[0:len_s, :]  # fetch sample 1 (deep-network features)
y = features[len_s:, :]  # fetch sample 2 (deep-network features)
distance_xx = self.pairwise_distance_squared(x, x)
distance_yy = self.pairwise_distance_squared(y, y)
distance_xy = self.pairwise_distance_squared(x, y)
if is_smooth:
x_original = features_org[0:len_s, :] # fetch the original sample 1
y_original = features_org[len_s:, :] # fetch the original sample 2
distance_xx_original = self.pairwise_distance_squared(x_original, x_original)
distance_yy_original = self.pairwise_distance_squared(y_original, y_original)
distance_xy_original = self.pairwise_distance_squared(x_original, y_original)
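# Deep kernel of Equation (1) of the paper, written out term by term:
# kappa_w(x, y) = [(1 - epsilon) * k(phi_w(x), phi_w(y)) + epsilon] * q(x, y), with k a (generalized)
# Gaussian on the deep features and q a Gaussian on the original input features.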
kernel_x = (1 - epsilon) * torch.exp(
-((distance_xx / sigma_phi) ** self.gaussian_degree) - distance_xx_original / sigma_q
) + epsilon * torch.exp(-distance_xx_original / sigma_q)
kernel_y = (1 - epsilon) * torch.exp(
-((distance_yy / sigma_phi) ** self.gaussian_degree) - distance_yy_original / sigma_q
) + epsilon * torch.exp(-distance_yy_original / sigma_q)
kernel_xy = (1 - epsilon) * torch.exp(
-((distance_xy / sigma_phi) ** self.gaussian_degree) - distance_xy_original / sigma_q
) + epsilon * torch.exp(-distance_xy_original / sigma_q)
else:
kernel_x = torch.exp(-distance_xx / sigma_phi)
kernel_y = torch.exp(-distance_yy / sigma_phi)
kernel_xy = torch.exp(-distance_xy / sigma_phi)
# kernel_x represents k_w(x_i, x_j), kernel_y represents k_w(y_i, y_j), kernel_xy represents
# k_w(x_i, y_j) for all i, j in the sample X and sample Y defined in Equation (1) of the paper
return self.h1_mean_var_gram(kernel_x, kernel_y, kernel_xy, is_var_computed)
def train_kernel(self, X: torch.Tensor, Y: torch.Tensor) -> None:
"""
Train the Deep MMD kernel.
Args:
X (torch.Tensor): The input tensor X.
Y (torch.Tensor): The input tensor Y.
"""
self.featurizer.train()
self.sigma_q_opt.requires_grad = True
self.sigma_phi_opt.requires_grad = True
self.epsilon_opt.requires_grad = True
# Shuffle the data to ensure they are not always presented in the same order for training
# which might lead to overfitting
indices = torch.randperm(Y.size(0))
Y_shuffled = Y[indices]
features = torch.cat([X, Y_shuffled], 0)
# ------------------------------
# Train deep network for MMD-D
# ------------------------------
# Zero gradients
self.optimizer_F.zero_grad()
# Compute output of deep network
model_output = self.featurizer(features)
# Compute epsilon, sigma_q and sigma_phi in \kappa_w(x, y) in Equation (1) of the paper
epsilon = torch.exp(self.epsilon_opt) / (1 + torch.exp(self.epsilon_opt))
sigma_q = self.sigma_q_opt**2
sigma_phi = self.sigma_phi_opt**2
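# The logistic transform keeps epsilon in (0, 1) and squaring keeps sigma_q and sigma_phi positive, so the
# deep kernel in Equation (1) remains well defined while these parameters are optimized.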
# Compute Deep MMD value and variance estimates
mmd_value_estimate, mmd_var_estimate = self.MMDu(
features=model_output,
len_s=X.shape[0],
features_org=features.view(features.shape[0], -1),
sigma_q=sigma_q,
sigma_phi=sigma_phi,
epsilon=epsilon,
is_var_computed=True,
)
if mmd_var_estimate is None:
raise AssertionError("Error: Variance of MMD is not computed. Please set is_var_computed=True.")
mmd_std_estimate = torch.sqrt(mmd_var_estimate)
# Forming \hat{J}_{\lambda} defined in Equation (4) of the paper (STAT_u)
stat_u = torch.div(-1 * mmd_value_estimate, mmd_std_estimate)
# Compute gradient
stat_u.backward()
# Update weights with the AdamW optimizer
self.optimizer_F.step()
def compute_kernel(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
"""
Compute the Deep MMD Loss.
Args:
X (torch.Tensor): The input tensor X.
Y (torch.Tensor): The input tensor Y.
Returns:
torch.Tensor: The value of Deep MMD Loss.
"""
self.featurizer.eval()
self.sigma_q_opt.requires_grad = False
self.sigma_phi_opt.requires_grad = False
self.epsilon_opt.requires_grad = False
features = torch.cat([X, Y], 0)
# Compute output of deep network
model_output = self.featurizer(features)
# Compute epsilon, sigma_q and sigma_phi in \kappa_w(x, y) in Equation (1) of the paper
epsilon = torch.exp(self.epsilon_opt) / (1 + torch.exp(self.epsilon_opt))
sigma_q = self.sigma_q_opt**2
sigma_phi = self.sigma_phi_opt**2
# Compute Deep MMD value estimates
mmd_value_estimate, _ = self.MMDu(
features=model_output,
len_s=X.shape[0],
features_org=features.view(features.shape[0], -1),
sigma_q=sigma_q,
sigma_phi=sigma_phi,
epsilon=epsilon,
is_var_computed=False,
)
return mmd_value_estimate
def forward(self, Xs: torch.Tensor, Xt: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the Deep MMD loss. If the module is in training mode, the deep kernel is first trained
for the configured number of optimization steps before the MMD loss is computed.
Args:
Xs (torch.Tensor): The source input tensor.
Xt (torch.Tensor): The target input tensor.
Returns:
torch.Tensor: The value of Deep MMD Loss.
"""
if self.training:
for _ in range(self.optimization_steps):
self.train_kernel(Xs.clone().detach(), Xt.clone().detach())
return self.compute_kernel(Xs, Xt)
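
# Minimal usage sketch (illustrative only): the feature dimension, batch sizes, and variable names below are
# assumptions for demonstration and are not prescribed by this module.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    deep_mmd_loss = DeepMmdLoss(device=device, input_size=64)
    # Two batches of feature representations to be aligned (for example, source and target features).
    source_features = torch.randn(32, 64, device=device)
    target_features = torch.randn(32, 64, device=device)
    # Enable the internal kernel optimization steps performed inside forward.
    deep_mmd_loss.training = True
    train_loss = deep_mmd_loss(source_features, target_features)
    # Disable kernel training so that forward only computes the MMD value.
    deep_mmd_loss.training = False
    eval_loss = deep_mmd_loss(source_features, target_features)
    print(train_loss.item(), eval_loss.item())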