Bases: Module
PyTorch implementation of the LM-Supervised Retriever Loss.
Given input context x and ground truth continuation y, computes KL divergence
between retrieval likelihood P_R(d|x) and language model likelihood Q_LM(d|x,y),
where d is the retrieved document.
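Concretely, following the REPLUG paper, the loss averaged over a batch (batch notation $\mathcal{B}$ is adapted here for illustration) can be written as

$$
\mathcal{L} = \frac{1}{|\mathcal{B}|} \sum_{x \in \mathcal{B}} \mathrm{KL}\bigl(P_R(d \mid x) \,\|\, Q_{\mathrm{LM}}(d \mid x, y)\bigr),
$$

where both distributions are normalized over the set of retrieved documents for each input.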
Shi, Weijia, et al. "Replug: Retrieval-augmented black-box language models."
arXiv preprint arXiv:2301.12652 (2023).
arXiv: https://arxiv.org/pdf/2301.12652
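A minimal usage sketch (the import path is assumed from the source location below and may differ in your install):

```python
import torch

from fed_rag.loss.pytorch.lsr import LSRLoss  # path assumed from the source location below

loss_fn = LSRLoss()  # defaults to ReductionMode.MEAN

batch_size, num_docs = 4, 8
retrieval_scores = torch.randn(batch_size, num_docs, requires_grad=True)  # retriever scores per document
lm_scores = torch.randn(batch_size, num_docs)  # LM scores, treated as a fixed target

loss = loss_fn(retrieval_scores, lm_scores)
loss.backward()  # gradients flow back to the retriever scores only
```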
Source code in src/fed_rag/loss/pytorch/lsr.py
class LSRLoss(nn.Module):
    """PyTorch implementation of the LM-Supervised Retriever Loss.

    Given input context x and ground truth continuation y, computes KL divergence
    between retrieval likelihood P_R(d|x) and language model likelihood Q_LM(d|x,y),
    where d is the retrieved document.

    Source: Shi, Weijia, et al. "Replug: Retrieval-augmented black-box language models."
    arXiv preprint arXiv:2301.12652 (2023).
    arXiv: https://arxiv.org/pdf/2301.12652
    """

    def __init__(self, reduction: ReductionMode = ReductionMode.MEAN):
        # This line is critical - it initializes all the Module machinery
        super().__init__()
        if reduction not in ReductionMode.members_list():
            msg = (
                f"Invalid reduction {reduction}. "
                f"Valid reductions are: {', '.join(ReductionMode.members_list())}"
            )
            raise InvalidReductionParam(msg)
        self.reduction = reduction

    def forward(
        self, retrieval_scores: torch.Tensor, lm_scores: torch.Tensor
    ) -> torch.Tensor:
        # Convert retrieval scores to log-probabilities and LM scores to
        # probabilities, matching F.kl_div's (log-probs, probs) convention.
        retrieval_log_probs = F.log_softmax(retrieval_scores, dim=1)
        lm_probs = F.softmax(lm_scores, dim=1)
        # Per-example KL divergence, summed over the document dimension.
        kl_div = F.kl_div(retrieval_log_probs, lm_probs, reduction="none").sum(
            dim=-1
        )
        match self.reduction:
            case ReductionMode.MEAN:
                return kl_div.mean()
            case ReductionMode.SUM:
                return kl_div.sum()
            case _:  # pragma: no cover
                assert_never(self.reduction)  # pragma: no cover
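Note that `F.kl_div` expects its first argument as log-probabilities and its second as probabilities. The per-example term computed in `forward` can be reproduced by hand (a small sketch with arbitrary tensor shapes):

```python
import torch
import torch.nn.functional as F

retrieval_scores = torch.randn(2, 5)
lm_scores = torch.randn(2, 5)

log_p = F.log_softmax(retrieval_scores, dim=1)  # retriever log-probs
q = F.softmax(lm_scores, dim=1)  # LM probs

# F.kl_div(input, target) computes target * (log target - input) elementwise,
# so summing over the document dimension gives the per-example KL divergence.
per_example = F.kl_div(log_p, q, reduction="none").sum(dim=-1)
manual = (q * (q.log() - log_p)).sum(dim=-1)
assert torch.allclose(per_example, manual)
```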